Skip to content

Commit 78d45a1

Browse files
mauvilsacarmocca
andauthored
Improve LightningCLI documentation and tests (#7156)
* - Added cli unit tests for help, print_config and submodules. - Added to cli documentation use of subclass help and print_config, submodules and other minor improvements. - Increased minimum jsonargparse version required for new documented features. * Improvements to lightning_cli.rst * Add check for all trainer parameters in test_lightning_cli_help * Increased minimum jsonargparse version Co-authored-by: Carlos Mocholí <[email protected]>
1 parent d123aaa commit 78d45a1

File tree

3 files changed

+168
-6
lines changed

3 files changed

+168
-6
lines changed

docs/source/common/lightning_cli.rst

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
MyModelBaseClass = MyModel
3434
MyDataModuleBaseClass = MyDataModule
3535

36+
EncoderBaseClass = MyModel
37+
DecoderBaseClass = MyModel
38+
3639
mock_argv = mock.patch("sys.argv", ["any.py"])
3740
mock_argv.start()
3841

@@ -116,7 +119,7 @@ The start of a possible implementation of :class:`MyModel` including the recomme
116119
docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
117120
information to define its configurable arguments.
118121

119-
.. code-block:: python
122+
.. testcode::
120123

121124
class MyModel(LightningModule):
122125

@@ -131,7 +134,8 @@ information to define its configurable arguments.
131134
encoder_layers: Number of layers for the encoder
132135
decoder_layers: Number of layers for each decoder block
133136
"""
134-
...
137+
super().__init__()
138+
self.save_hyperparameters()
135139

136140
With this model class, the help of the trainer tool would look as follows:
137141

@@ -258,7 +262,67 @@ A possible config file could be as follows:
258262
...
259263
260264
Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of
261-
:code:`MyDataModuleBaseClass`.
265+
:code:`MyDataModuleBaseClass`. If as base classes :class:`~pytorch_lightning.core.lightning.LightningModule` and
266+
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` are given, then the tool would allow any lightning
267+
module and data module.
268+
269+
.. tip::
270+
271+
Note that with the subclass modes the :code:`--help` option does not show information for a specific subclass. To
272+
get help for a subclass the options :code:`--model.help` and :code:`--data.help` can be used, followed by the
273+
desired class path. Similarly :code:`--print_config` does not include the settings for a particular subclass. To
274+
include them the class path should be given before the :code:`--print_config` option. Examples for both help and
275+
print config are:
276+
277+
.. code-block:: bash
278+
279+
$ python trainer.py --model.help mycode.mymodels.MyModel
280+
$ python trainer.py --model mycode.mymodels.MyModel --print_config
281+
282+
283+
Models with multiple submodules
284+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
285+
286+
Many use cases require to have several modules each with its own configurable options. One possible way to handle this
287+
with LightningCLI is to implement a single module having as init parameters each of the submodules. Since the init
288+
parameters have as type a class, then in the configuration these would be specified with :code:`class_path` and
289+
:code:`init_args` entries. For instance a model could be implemented as:
290+
291+
.. testcode::
292+
293+
class MyMainModel(LightningModule):
294+
295+
def __init__(
296+
self,
297+
encoder: EncoderBaseClass,
298+
decoder: DecoderBaseClass
299+
):
300+
"""Example encoder-decoder submodules model
301+
302+
Args:
303+
encoder: Instance of a module for encoding
304+
decoder: Instance of a module for decoding
305+
"""
306+
super().__init__()
307+
self.encoder = encoder
308+
self.decoder = decoder
309+
310+
If the CLI is implemented as :code:`LightningCLI(MyMainModel)` the configuration would be as follows:
311+
312+
.. code-block:: yaml
313+
314+
model:
315+
encoder:
316+
class_path: mycode.myencoders.MyEncoder
317+
init_args:
318+
...
319+
decoder:
320+
class_path: mycode.mydecoders.MyDecoder
321+
init_args:
322+
...
323+
324+
It is also possible to combine :code:`subclass_mode_model=True` and submodules, thereby having two levels of
325+
:code:`class_path`.
262326

263327

264328
Customizing LightningCLI
@@ -275,7 +339,7 @@ extended to customize different parts of the command line tool. The argument par
275339
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to
276340
add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring
277341
parameters to have type hints. For more details about this please refer to the `respective documentation
278-
<https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.
342+
<https://jsonargparse.readthedocs.io/en/stable/#classes-methods-and-functions>`_.
279343

280344
The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
281345
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ torchtext>=0.5
77
# onnx>=1.7.0
88
onnxruntime>=1.3.0
99
hydra-core>=1.0
10-
jsonargparse[signatures]>=3.9.0
10+
jsonargparse[signatures]>=3.10.1

tests/utilities/test_cli.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
import json
1617
import os
1718
import pickle
1819
import sys
1920
from argparse import Namespace
21+
from contextlib import redirect_stdout
22+
from io import StringIO
2023
from unittest import mock
2124

2225
import pytest
2326
import yaml
2427

25-
from pytorch_lightning import LightningModule, Trainer
28+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
2629
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
2730
from pytorch_lightning.utilities import _TPU_AVAILABLE
2831
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
@@ -329,3 +332,98 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
329332
assert config['model'] == cli.config['model']
330333
assert config['data'] == cli.config['data']
331334
assert config['trainer'] == cli.config['trainer']
335+
336+
337+
def any_model_any_data_cli():
338+
LightningCLI(
339+
LightningModule,
340+
LightningDataModule,
341+
subclass_mode_model=True,
342+
subclass_mode_data=True,
343+
)
344+
345+
346+
def test_lightning_cli_help():
347+
348+
cli_args = ['any.py', '--help']
349+
out = StringIO()
350+
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
351+
any_model_any_data_cli()
352+
353+
assert '--print_config' in out.getvalue()
354+
assert '--config' in out.getvalue()
355+
assert '--seed_everything' in out.getvalue()
356+
assert '--model.help' in out.getvalue()
357+
assert '--data.help' in out.getvalue()
358+
359+
skip_params = {'self'}
360+
for param in inspect.signature(Trainer.__init__).parameters.keys():
361+
if param not in skip_params:
362+
assert f'--trainer.{param}' in out.getvalue()
363+
364+
cli_args = ['any.py', '--data.help=tests.helpers.BoringDataModule']
365+
out = StringIO()
366+
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
367+
any_model_any_data_cli()
368+
369+
assert '--data.init_args.data_dir' in out.getvalue()
370+
371+
372+
def test_lightning_cli_print_config():
373+
374+
cli_args = [
375+
'any.py',
376+
'--seed_everything=1234',
377+
'--model=tests.helpers.BoringModel',
378+
'--data=tests.helpers.BoringDataModule',
379+
'--print_config',
380+
]
381+
382+
out = StringIO()
383+
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
384+
any_model_any_data_cli()
385+
386+
outval = yaml.safe_load(out.getvalue())
387+
assert outval['seed_everything'] == 1234
388+
assert outval['model']['class_path'] == 'tests.helpers.BoringModel'
389+
assert outval['data']['class_path'] == 'tests.helpers.BoringDataModule'
390+
391+
392+
def test_lightning_cli_submodules(tmpdir):
393+
394+
class MainModule(BoringModel):
395+
def __init__(
396+
self,
397+
submodule1: LightningModule,
398+
submodule2: LightningModule,
399+
main_param: int = 1,
400+
):
401+
super().__init__()
402+
self.submodule1 = submodule1
403+
self.submodule2 = submodule2
404+
405+
config = """model:
406+
main_param: 2
407+
submodule1:
408+
class_path: tests.helpers.BoringModel
409+
submodule2:
410+
class_path: tests.helpers.BoringModel
411+
"""
412+
config_path = tmpdir / 'config.yaml'
413+
with open(config_path, 'w') as f:
414+
f.write(config)
415+
416+
cli_args = [
417+
f'--trainer.default_root_dir={tmpdir}',
418+
'--trainer.max_epochs=1',
419+
f'--config={str(config_path)}',
420+
]
421+
422+
with mock.patch('sys.argv', ['any.py'] + cli_args):
423+
cli = LightningCLI(MainModule)
424+
425+
assert cli.config_init['model']['main_param'] == 2
426+
assert cli.model.submodule1 == cli.config_init['model']['submodule1']
427+
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
428+
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
429+
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)

0 commit comments

Comments
 (0)