
Commit 78a6fd5

mauvilsa and awaelchli authored
Example and documentation for LightningCLI linking model and data arguments (#7299)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent bf1394a commit 78a6fd5

4 files changed, +90 -122 lines changed

docs/source/common/lightning_cli.rst

Lines changed: 46 additions & 12 deletions
@@ -14,18 +14,14 @@
     def __init__(
         self,
         encoder_layers: int = 12,
-        decoder_layers: List[int] = [2, 4]
+        decoder_layers: List[int] = [2, 4],
+        batch_size: int = 8,
     ):
-        """Example encoder-decoder model
-
-        Args:
-            encoder_layers: Number of layers for the encoder
-            decoder_layers: Number of layers for each decoder block
-        """
         pass
 
 class MyDataModule(LightningDataModule):
-    pass
+    def __init__(self, batch_size: int = 8):
+        pass
 
 def send_email(address, message):
     pass
@@ -119,7 +115,7 @@ The start of a possible implementation of :class:`MyModel` including the recommended
 docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
 information to define its configurable arguments.
 
-.. testcode::
+.. testcode:: mymodel
 
     class MyModel(LightningModule):
 
@@ -373,8 +369,46 @@ before and after the execution of fit. The code would be something like:
     cli = MyLightningCLI(MyModel)
 
 Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
-has the same structure as the yaml format as described previously. This means for instance that the parameters used for
+has the same structure as the yaml format described previously. This means, for instance, that the parameters used for
 instantiating the trainer class can be found in :code:`self.config['trainer']`.
 
-For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
-extended. For further information have a look at the corresponding API reference.
+Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is when the
+model and data module depend on a common parameter. For example, in some cases both classes need to know the
+:code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this, the
+parser can be configured so that a value is only given once and then propagated accordingly. With a tool implemented
+as shown below, the :code:`batch_size` only has to be provided in the :code:`data` section of the config.
+
+.. testcode::
+
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments('data.batch_size', 'model.batch_size')
+
+    cli = MyLightningCLI(MyModel, MyDataModule)
+
+The linking of arguments is reflected in the help of the tool, which for this example would look like:
+
+.. code-block:: bash
+
+    $ python trainer.py --help
+    ...
+      --data.batch_size BATCH_SIZE
+                            Number of samples in a batch (type: int, default: 8)
+
+    Linked arguments:
+      model.batch_size <-- data.batch_size
+                            Number of samples in a batch (type: int)
+
+.. tip::
+
+    The linking of arguments can also be used for more complex cases, for example to derive a value via a function
+    that takes multiple settings as input. For more details have a look at the API of `link_arguments
+    <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.link_arguments>`_.
+
+.. tip::
+
+    Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
+    methods that can be extended to customize a CLI.
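
The second tip refers to deriving one option from several others. A minimal sketch of that pattern, assuming
jsonargparse's compute_fn parameter of link_arguments and hypothetical data.image_height, data.image_width and
model.input_size options (none of which are part of this commit), could look like:

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # Compute model.input_size from two data options at parse time.
            # The option names here are hypothetical; passing a tuple of
            # sources together with compute_fn is jsonargparse's
            # multi-source form of link_arguments.
            parser.link_arguments(
                ('data.image_height', 'data.image_width'),
                'model.input_size',
                compute_fn=lambda h, w: h * w,
            )

    cli = MyLightningCLI(MyModel, MyDataModule)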

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 39 additions & 97 deletions
@@ -37,9 +37,8 @@
 Note:
     See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
-import argparse
+
 import logging
-import os
 from pathlib import Path
 from typing import Union
@@ -59,6 +58,7 @@
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
 from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.cli import LightningCLI
 
 log = logging.getLogger(__name__)
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@@ -93,10 +93,17 @@ class CatDogImageDataModule(LightningDataModule):
 
     def __init__(
         self,
-        dl_path: Union[str, Path],
+        dl_path: Union[str, Path] = "data",
         num_workers: int = 0,
         batch_size: int = 8,
     ):
+        """CatDogImageDataModule
+
+        Args:
+            dl_path: root directory where to download the data
+            num_workers: number of CPU workers
+            batch_size: number of samples in a batch
+        """
         super().__init__()
 
         self._dl_path = dl_path
@@ -146,17 +153,6 @@ def val_dataloader(self):
         log.info("Validation data loaded.")
         return self.__dataloader(train=False)
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("CatDogImageDataModule")
-        parser.add_argument(
-            "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
-        )
-        parser.add_argument(
-            "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
-        )
-        return parent_parser
-
 
 # --- Pytorch-lightning module ---
 
@@ -166,17 +162,22 @@ class TransferLearningModel(pl.LightningModule):
     def __init__(
         self,
         backbone: str = "resnet50",
-        train_bn: bool = True,
-        milestones: tuple = (5, 10),
+        train_bn: bool = False,
+        milestones: tuple = (2, 4),
         batch_size: int = 32,
-        lr: float = 1e-2,
+        lr: float = 1e-3,
         lr_scheduler_gamma: float = 1e-1,
         num_workers: int = 6,
         **kwargs,
     ) -> None:
-        """
+        """TransferLearningModel
+
         Args:
-            dl_path: Path where the data will be downloaded
+            backbone: Name (as in ``torchvision.models``) of the feature extractor
+            train_bn: Whether the BatchNorm layers should be trainable
+            milestones: List of two epochs milestones
+            lr: Initial learning rate
+            lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
         """
         super().__init__()
         self.backbone = backbone
@@ -269,90 +270,31 @@ def configure_optimizers(self):
         scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
         return [optimizer], [scheduler]
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("TransferLearningModel")
-        parser.add_argument(
-            "--backbone",
-            default="resnet50",
-            type=str,
-            metavar="BK",
-            help="Name (as in ``torchvision.models``) of the feature extractor",
-        )
-        parser.add_argument(
-            "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
-        )
-        parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
-        parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
-        parser.add_argument(
-            "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
-        )
-        parser.add_argument(
-            "--lr-scheduler-gamma",
-            default=1e-1,
-            type=float,
-            metavar="LRG",
-            help="Factor by which the learning rate is reduced at each milestone",
-            dest="lr_scheduler_gamma",
-        )
-        parser.add_argument(
-            "--train-bn",
-            default=False,
-            type=bool,
-            metavar="TB",
-            help="Whether the BatchNorm layers should be trainable",
-            dest="train_bn",
-        )
-        parser.add_argument(
-            "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones"
-        )
-        return parent_parser
-
-
-def main(args: argparse.Namespace) -> None:
-    """Train the model.
-
-    Args:
-        args: Model hyper-parameters
-
-    Note:
-        For the sake of the example, the images dataset will be downloaded
-        to a temporary directory.
-    """
 
-    datamodule = CatDogImageDataModule(
-        dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
-    )
-    model = TransferLearningModel(**vars(args))
-    finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
+class MyLightningCLI(LightningCLI):
 
-    trainer = pl.Trainer(
-        weights_summary=None,
-        progress_bar_refresh_rate=1,
-        num_sanity_val_steps=0,
-        gpus=args.gpus,
-        max_epochs=args.nb_epochs,
-        callbacks=[finetuning_callback]
-    )
+    def add_arguments_to_parser(self, parser):
+        parser.add_class_arguments(MilestonesFinetuning, 'finetuning')
+        parser.link_arguments('data.batch_size', 'model.batch_size')
+        parser.link_arguments('finetuning.milestones', 'model.milestones')
+        parser.link_arguments('finetuning.train_bn', 'model.train_bn')
+        parser.set_defaults({
+            'trainer.max_epochs': 15,
+            'trainer.weights_summary': None,
+            'trainer.progress_bar_refresh_rate': 1,
+            'trainer.num_sanity_val_steps': 0,
+        })
 
-    trainer.fit(model, datamodule=datamodule)
+    def instantiate_trainer(self):
+        finetuning_callback = MilestonesFinetuning(**self.config_init['finetuning'])
+        self.trainer_defaults['callbacks'] = [finetuning_callback]
+        super().instantiate_trainer()
 
 
-def get_args() -> argparse.Namespace:
-    parent_parser = argparse.ArgumentParser(add_help=False)
-    parent_parser.add_argument(
-        "--root-data-path",
-        metavar="DIR",
-        type=str,
-        default=Path.cwd().as_posix(),
-        help="Root directory where to download the data",
-        dest="root_data_path",
-    )
-    parser = TransferLearningModel.add_model_specific_args(parent_parser)
-    parser = CatDogImageDataModule.add_argparse_args(parser)
-    return parser.parse_args()
+def cli_main():
+    MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)
 
 
 if __name__ == "__main__":
     cli_lightning_logo()
-    main(get_args())
+    cli_main()
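
The heart of the rewritten example is the argument linking done in add_arguments_to_parser. To see the propagation in
isolation, here is a minimal sketch using plain jsonargparse, on which LightningCLI is built; the Model and Data
classes below are stand-ins for illustration, not part of the commit:

    from jsonargparse import ArgumentParser

    class Model:
        def __init__(self, batch_size: int = 8):
            self.batch_size = batch_size

    class Data:
        def __init__(self, batch_size: int = 8):
            self.batch_size = batch_size

    parser = ArgumentParser()
    parser.add_class_arguments(Model, 'model')
    parser.add_class_arguments(Data, 'data')
    # As in MyLightningCLI above: batch_size is given once, under 'data',
    # and the link fills in the model's value while parsing.
    parser.link_arguments('data.batch_size', 'model.batch_size')

    cfg = parser.parse_args(['--data.batch_size', '32'])
    print(cfg.model.batch_size)  # 32, propagated from data.batch_size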

pytorch_lightning/utilities/cli.py

Lines changed: 4 additions & 12 deletions
@@ -161,9 +161,8 @@ def __init__(
         self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse})
 
         self.init_parser()
-        self.add_arguments_to_parser(self.parser)
         self.add_core_arguments_to_parser()
-        self.before_parse_arguments(self.parser)
+        self.add_arguments_to_parser(self.parser)
         self.parse_arguments()
         if self.config['seed_everything'] is not None:
             seed_everything(self.config['seed_everything'])
@@ -178,13 +177,6 @@ def init_parser(self) -> None:
         """Method that instantiates the argument parser"""
         self.parser = LightningArgumentParser(**self.parser_kwargs)
 
-    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Implement to add extra arguments to parser
-
-        Args:
-            parser: The argument parser object to which arguments should be added
-        """
-
     def add_core_arguments_to_parser(self) -> None:
         """Adds arguments from the core classes to the parser"""
         self.parser.add_argument(
@@ -200,11 +192,11 @@ def add_core_arguments_to_parser(self) -> None:
         if self.datamodule_class is not None:
             self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data)
 
-    def before_parse_arguments(self, parser: LightningArgumentParser) -> None:
-        """Implement to run some code before parsing arguments
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Implement to add extra arguments to parser or link arguments
 
         Args:
-            parser: The argument parser object that will be used to parse
+            parser: The argument parser object to which arguments can be added
         """
 
     def parse_arguments(self) -> None:
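
The reordering in __init__ is what makes linking possible from add_arguments_to_parser: the hook now runs after
add_core_arguments_to_parser, so the trainer.*, model.* and data.* options already exist as link sources and targets.
The hook also absorbs the removed before_parse_arguments; a subclass that overrode that hook would migrate roughly as
in the sketch below (the trainer.max_epochs default is only illustrative):

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyCLI(LightningCLI):

        # Code that previously lived in before_parse_arguments() moves here;
        # this hook now runs immediately before parse_arguments().
        def add_arguments_to_parser(self, parser):
            parser.set_defaults({'trainer.max_epochs': 10})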

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@ torchtext>=0.5
 # onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-jsonargparse[signatures]>=3.11.0
+jsonargparse[signatures]>=3.11.1
