37
37
Note:
38
38
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
39
39
"""
40
- import argparse
40
+
41
41
import logging
42
- import os
43
42
from pathlib import Path
44
43
from typing import Union
45
44
59
58
from pytorch_lightning import LightningDataModule
60
59
from pytorch_lightning .callbacks .finetuning import BaseFinetuning
61
60
from pytorch_lightning .utilities import rank_zero_info
61
+ from pytorch_lightning .utilities .cli import LightningCLI
62
62
63
63
log = logging .getLogger (__name__ )
64
64
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@@ -93,10 +93,17 @@ class CatDogImageDataModule(LightningDataModule):
93
93
94
94
def __init__ (
95
95
self ,
96
- dl_path : Union [str , Path ],
96
+ dl_path : Union [str , Path ] = "data" ,
97
97
num_workers : int = 0 ,
98
98
batch_size : int = 8 ,
99
99
):
100
+ """CatDogImageDataModule
101
+
102
+ Args:
103
+ dl_path: root directory where to download the data
104
+ num_workers: number of CPU workers
105
+ batch_size: number of sample in a batch
106
+ """
100
107
super ().__init__ ()
101
108
102
109
self ._dl_path = dl_path
@@ -146,17 +153,6 @@ def val_dataloader(self):
146
153
log .info ("Validation data loaded." )
147
154
return self .__dataloader (train = False )
148
155
149
- @staticmethod
150
- def add_model_specific_args (parent_parser ):
151
- parser = parent_parser .add_argument_group ("CatDogImageDataModule" )
152
- parser .add_argument (
153
- "--num-workers" , default = 0 , type = int , metavar = "W" , help = "number of CPU workers" , dest = "num_workers"
154
- )
155
- parser .add_argument (
156
- "--batch-size" , default = 8 , type = int , metavar = "W" , help = "number of sample in a batch" , dest = "batch_size"
157
- )
158
- return parent_parser
159
-
160
156
161
157
# --- Pytorch-lightning module ---
162
158
@@ -166,17 +162,22 @@ class TransferLearningModel(pl.LightningModule):
166
162
def __init__ (
167
163
self ,
168
164
backbone : str = "resnet50" ,
169
- train_bn : bool = True ,
170
- milestones : tuple = (5 , 10 ),
165
+ train_bn : bool = False ,
166
+ milestones : tuple = (2 , 4 ),
171
167
batch_size : int = 32 ,
172
- lr : float = 1e-2 ,
168
+ lr : float = 1e-3 ,
173
169
lr_scheduler_gamma : float = 1e-1 ,
174
170
num_workers : int = 6 ,
175
171
** kwargs ,
176
172
) -> None :
177
- """
173
+ """TransferLearningModel
174
+
178
175
Args:
179
- dl_path: Path where the data will be downloaded
176
+ backbone: Name (as in ``torchvision.models``) of the feature extractor
177
+ train_bn: Whether the BatchNorm layers should be trainable
178
+ milestones: List of two epochs milestones
179
+ lr: Initial learning rate
180
+ lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
180
181
"""
181
182
super ().__init__ ()
182
183
self .backbone = backbone
@@ -269,90 +270,31 @@ def configure_optimizers(self):
269
270
scheduler = MultiStepLR (optimizer , milestones = self .milestones , gamma = self .lr_scheduler_gamma )
270
271
return [optimizer ], [scheduler ]
271
272
272
- @staticmethod
273
- def add_model_specific_args (parent_parser ):
274
- parser = parent_parser .add_argument_group ("TransferLearningModel" )
275
- parser .add_argument (
276
- "--backbone" ,
277
- default = "resnet50" ,
278
- type = str ,
279
- metavar = "BK" ,
280
- help = "Name (as in ``torchvision.models``) of the feature extractor" ,
281
- )
282
- parser .add_argument (
283
- "--epochs" , default = 15 , type = int , metavar = "N" , help = "total number of epochs" , dest = "nb_epochs"
284
- )
285
- parser .add_argument ("--batch-size" , default = 8 , type = int , metavar = "B" , help = "batch size" , dest = "batch_size" )
286
- parser .add_argument ("--gpus" , type = int , default = 0 , help = "number of gpus to use" )
287
- parser .add_argument (
288
- "--lr" , "--learning-rate" , default = 1e-3 , type = float , metavar = "LR" , help = "initial learning rate" , dest = "lr"
289
- )
290
- parser .add_argument (
291
- "--lr-scheduler-gamma" ,
292
- default = 1e-1 ,
293
- type = float ,
294
- metavar = "LRG" ,
295
- help = "Factor by which the learning rate is reduced at each milestone" ,
296
- dest = "lr_scheduler_gamma" ,
297
- )
298
- parser .add_argument (
299
- "--train-bn" ,
300
- default = False ,
301
- type = bool ,
302
- metavar = "TB" ,
303
- help = "Whether the BatchNorm layers should be trainable" ,
304
- dest = "train_bn" ,
305
- )
306
- parser .add_argument (
307
- "--milestones" , default = [2 , 4 ], type = list , metavar = "M" , help = "List of two epochs milestones"
308
- )
309
- return parent_parser
310
-
311
-
312
- def main (args : argparse .Namespace ) -> None :
313
- """Train the model.
314
-
315
- Args:
316
- args: Model hyper-parameters
317
-
318
- Note:
319
- For the sake of the example, the images dataset will be downloaded
320
- to a temporary directory.
321
- """
322
273
323
- datamodule = CatDogImageDataModule (
324
- dl_path = os .path .join (args .root_data_path , 'data' ), batch_size = args .batch_size , num_workers = args .num_workers
325
- )
326
- model = TransferLearningModel (** vars (args ))
327
- finetuning_callback = MilestonesFinetuning (milestones = args .milestones )
274
+ class MyLightningCLI (LightningCLI ):
328
275
329
- trainer = pl .Trainer (
330
- weights_summary = None ,
331
- progress_bar_refresh_rate = 1 ,
332
- num_sanity_val_steps = 0 ,
333
- gpus = args .gpus ,
334
- max_epochs = args .nb_epochs ,
335
- callbacks = [finetuning_callback ]
336
- )
276
+ def add_arguments_to_parser (self , parser ):
277
+ parser .add_class_arguments (MilestonesFinetuning , 'finetuning' )
278
+ parser .link_arguments ('data.batch_size' , 'model.batch_size' )
279
+ parser .link_arguments ('finetuning.milestones' , 'model.milestones' )
280
+ parser .link_arguments ('finetuning.train_bn' , 'model.train_bn' )
281
+ parser .set_defaults ({
282
+ 'trainer.max_epochs' : 15 ,
283
+ 'trainer.weights_summary' : None ,
284
+ 'trainer.progress_bar_refresh_rate' : 1 ,
285
+ 'trainer.num_sanity_val_steps' : 0 ,
286
+ })
337
287
338
- trainer .fit (model , datamodule = datamodule )
288
+ def instantiate_trainer (self ):
289
+ finetuning_callback = MilestonesFinetuning (** self .config_init ['finetuning' ])
290
+ self .trainer_defaults ['callbacks' ] = [finetuning_callback ]
291
+ super ().instantiate_trainer ()
339
292
340
293
341
- def get_args () -> argparse .Namespace :
342
- parent_parser = argparse .ArgumentParser (add_help = False )
343
- parent_parser .add_argument (
344
- "--root-data-path" ,
345
- metavar = "DIR" ,
346
- type = str ,
347
- default = Path .cwd ().as_posix (),
348
- help = "Root directory where to download the data" ,
349
- dest = "root_data_path" ,
350
- )
351
- parser = TransferLearningModel .add_model_specific_args (parent_parser )
352
- parser = CatDogImageDataModule .add_argparse_args (parser )
353
- return parser .parse_args ()
294
+ def cli_main ():
295
+ MyLightningCLI (TransferLearningModel , CatDogImageDataModule , seed_everything_default = 1234 )
354
296
355
297
356
298
if __name__ == "__main__" :
357
299
cli_lightning_logo ()
358
- main ( get_args () )
300
+ cli_main ( )
0 commit comments