Skip to content

Commit 1bb5fcc

Browse files
tchaton and carmocca authored
[CLI] Shorthand notation to instantiate callbacks [3/3] (#8815)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent bbcb977 commit 1bb5fcc

File tree

5 files changed

+241
-12
lines changed

5 files changed

+241
-12
lines changed

.github/workflows/code-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ jobs:
1919
run: |
2020
grep mypy requirements/test.txt | xargs -0 pip install
2121
pip list
22-
- run: mypy
22+
- run: mypy --install-types --non-interactive

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5757
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
5858
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
5959
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
60+
* Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))
6061

6162

6263
- Fault-tolerant training:

docs/source/common/lightning_cli.rst

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from unittest import mock
66
from typing import List
77
import pytorch_lightning as pl
8-
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
8+
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback
99

1010

1111
class NoFitTrainer(Trainer):
@@ -371,6 +371,59 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr
371371
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
372372
the same way using :code:`class_path` and :code:`init_args`.
373373
374+
For callbacks in particular, Lightning simplifies the command line so that only
375+
the :class:`~pytorch_lightning.callbacks.Callback` name is required.
376+
The argument's order matters and the user needs to pass the arguments in the following way.
377+
378+
.. code-block:: bash
379+
380+
$ python ... \
381+
--trainer.callbacks={CALLBACK_1_NAME} \
382+
--trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
383+
--trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
384+
...
385+
--trainer.callbacks={CALLBACK_N_NAME} \
386+
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
387+
...
388+
389+
Here is an example:
390+
391+
.. code-block:: bash
392+
393+
$ python ... \
394+
--trainer.callbacks=EarlyStopping \
395+
--trainer.callbacks.patience=5 \
396+
--trainer.callbacks=LearningRateMonitor \
397+
--trainer.callbacks.logging_interval=epoch
398+
399+
Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
400+
as described above:
401+
402+
.. code-block:: python
403+
404+
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
405+
406+
407+
@CALLBACK_REGISTRY
408+
class CustomCallback(Callback):
409+
...
410+
411+
412+
cli = LightningCLI(...)
413+
414+
.. code-block:: bash
415+
416+
$ python ... --trainer.callbacks=CustomCallback ...
417+
418+
This callback will be included in the generated config:
419+
420+
.. code-block:: yaml
421+
422+
trainer:
423+
callbacks:
424+
- class_path: your_class_path.CustomCallback
425+
init_args:
426+
...
374427
375428
Multiple models and/or datasets
376429
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -517,9 +570,10 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
517570
Configurable callbacks
518571
^^^^^^^^^^^^^^^^^^^^^^
519572
520-
As explained previously, any callback can be added by including it in the config via :code:`class_path` and
521-
:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
522-
configurable. This can be implemented as follows:
573+
As explained previously, any Lightning callback can be added by passing it through command line or
574+
including it in the config via :code:`class_path` and :code:`init_args` entries.
575+
However, there are other cases in which a callback should always be present and be configurable.
576+
This can be implemented as follows:
523577
524578
.. testcode::
525579

pytorch_lightning/utilities/cli.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from unittest import mock
2121

2222
import torch
23+
import yaml
2324
from torch.optim import Optimizer
2425

26+
import pytorch_lightning as pl
2527
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
2628
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
2729
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -83,12 +85,15 @@ def __str__(self) -> str:
8385
LR_SCHEDULER_REGISTRY = _Registry()
8486
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
8587

88+
CALLBACK_REGISTRY = _Registry()
89+
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
90+
8691

8792
class LightningArgumentParser(ArgumentParser):
8893
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
8994

9095
# use class attribute because `parse_args` is only called on the main parser
91-
_choices: Dict[str, Tuple[Type, ...]] = {}
96+
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}
9297

9398
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
9499
"""Initialize argument parser that supports configuration file input.
@@ -202,23 +207,35 @@ def add_lr_scheduler_args(
202207

203208
def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
204209
argv = sys.argv
205-
for k, classes in self._choices.items():
210+
for k, v in self._choices.items():
206211
if not any(arg.startswith(f"--{k}") for arg in argv):
207212
# the key wasn't passed - maybe defined in a config, maybe it's optional
208213
continue
209-
argv = self._convert_argv_issue_84(classes, k, argv)
214+
classes, is_list = v
215+
# knowing whether the argument is a list type automatically would be too complex
216+
if is_list:
217+
argv = self._convert_argv_issue_85(classes, k, argv)
218+
else:
219+
argv = self._convert_argv_issue_84(classes, k, argv)
210220
self._choices.clear()
211221
with mock.patch("sys.argv", argv):
212222
return super().parse_args(*args, **kwargs)
213223

214-
def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
215-
self._choices[nested_key] = classes
224+
def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
225+
"""Adds support for shorthand notation for a particular nested key.
226+
227+
Args:
228+
nested_key: The key whose choices will be set.
229+
classes: A tuple of classes to choose from.
230+
is_list: Whether the argument is a ``List[object]`` type.
231+
"""
232+
self._choices[nested_key] = (classes, is_list)
216233

217234
@staticmethod
218235
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
219236
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
220237
221-
This should be removed once implemented.
238+
Adds support for shorthand notation for ``object`` arguments.
222239
"""
223240
passed_args, clean_argv = {}, []
224241
argv_key = f"--{nested_key}"
@@ -259,6 +276,64 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis
259276
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
260277
return clean_argv + [argv_key, config]
261278

279+
@staticmethod
280+
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
281+
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
282+
283+
Adds support for shorthand notation for ``List[object]`` arguments.
284+
"""
285+
passed_args, clean_argv = [], []
286+
passed_configs = {}
287+
argv_key = f"--{nested_key}"
288+
# get the argv args for this nested key
289+
i = 0
290+
while i < len(argv):
291+
arg = argv[i]
292+
if arg.startswith(argv_key):
293+
if "=" in arg:
294+
key, value = arg.split("=")
295+
else:
296+
key = arg
297+
i += 1
298+
value = argv[i]
299+
if "class_path" in value:
300+
# the user passed a config as a dict
301+
passed_configs[key] = yaml.safe_load(value)
302+
else:
303+
passed_args.append((key, value))
304+
else:
305+
clean_argv.append(arg)
306+
i += 1
307+
# generate the associated config file
308+
config = []
309+
i, n = 0, len(passed_args)
310+
while i < n - 1:
311+
ki, vi = passed_args[i]
312+
# convert class name to class path
313+
for cls in classes:
314+
if cls.__name__ == vi:
315+
cls_type = cls
316+
break
317+
else:
318+
raise ValueError(f"Could not generate a config for {repr(vi)}")
319+
config.append(_global_add_class_path(cls_type))
320+
# get any init args
321+
j = i + 1 # in case the j-loop doesn't run
322+
for j in range(i + 1, n):
323+
kj, vj = passed_args[j]
324+
if ki == kj:
325+
break
326+
if kj.startswith(ki):
327+
init_arg_name = kj.split(".")[-1]
328+
config[-1]["init_args"][init_arg_name] = vj
329+
i = j
330+
# update at the end to preserve the order
331+
for k, v in passed_configs.items():
332+
config.extend(v)
333+
if not config:
334+
return clean_argv
335+
return clean_argv + [argv_key, str(config)]
336+
262337

263338
class SaveConfigCallback(Callback):
264339
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -430,6 +505,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No
430505
def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
431506
"""Adds arguments from the core classes to the parser."""
432507
parser.add_lightning_class_args(self.trainer_class, "trainer")
508+
parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
433509
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
434510
parser.set_defaults(trainer_defaults)
435511
parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)

tests/utilities/test_cli.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pytorch_lightning.trainer.states import TrainerFn
3535
from pytorch_lightning.utilities import _TPU_AVAILABLE
3636
from pytorch_lightning.utilities.cli import (
37+
CALLBACK_REGISTRY,
3738
instantiate_class,
3839
LightningArgumentParser,
3940
LightningCLI,
@@ -861,6 +862,11 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
861862
pass
862863

863864

865+
@CALLBACK_REGISTRY
866+
class CustomCallback(Callback):
867+
pass
868+
869+
864870
def test_registries(tmpdir):
865871
assert "SGD" in OPTIMIZER_REGISTRY.names
866872
assert "RMSprop" in OPTIMIZER_REGISTRY.names
@@ -870,23 +876,41 @@ def test_registries(tmpdir):
870876
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
871877
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
872878

879+
assert "EarlyStopping" in CALLBACK_REGISTRY.names
880+
assert "CustomCallback" in CALLBACK_REGISTRY.names
881+
873882
with pytest.raises(MisconfigurationException, match="is already present in the registry"):
874883
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
875884
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
876885

877886

878-
def test_registries_resolution():
887+
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
888+
def test_registries_resolution(use_class_path_callbacks):
879889
"""This test validates registries are used when simplified command line are being used."""
880890
cli_args = [
881891
"--optimizer",
882892
"Adam",
883893
"--optimizer.lr",
884894
"0.0001",
895+
"--trainer.callbacks=LearningRateMonitor",
896+
"--trainer.callbacks.logging_interval=epoch",
897+
"--trainer.callbacks.log_momentum=True",
898+
"--trainer.callbacks=ModelCheckpoint",
899+
"--trainer.callbacks.monitor=loss",
885900
"--lr_scheduler",
886901
"StepLR",
887902
"--lr_scheduler.step_size=50",
888903
]
889904

905+
extras = []
906+
if use_class_path_callbacks:
907+
callbacks = [
908+
{"class_path": "pytorch_lightning.callbacks.Callback"},
909+
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
910+
]
911+
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
912+
extras = [Callback, Callback]
913+
890914
with mock.patch("sys.argv", ["any.py"] + cli_args):
891915
cli = LightningCLI(BoringModel, run=False)
892916

@@ -895,6 +919,80 @@ def test_registries_resolution():
895919
assert optimizers[0].param_groups[0]["lr"] == 0.0001
896920
assert lr_scheduler[0].step_size == 50
897921

922+
callback_types = [type(c) for c in cli.trainer.callbacks]
923+
expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
924+
assert all(t in callback_types for t in expected)
925+
926+
927+
def test_argv_transformation_noop():
928+
base = ["any.py", "--trainer.max_epochs=1"]
929+
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
930+
assert argv == base
931+
932+
933+
def test_argv_transformation_single_callback():
934+
base = ["any.py", "--trainer.max_epochs=1"]
935+
input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
936+
callbacks = [
937+
{
938+
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
939+
"init_args": {"monitor": "val_loss"},
940+
}
941+
]
942+
expected = base + ["--trainer.callbacks", str(callbacks)]
943+
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
944+
assert argv == expected
945+
946+
947+
def test_argv_transformation_multiple_callbacks():
948+
base = ["any.py", "--trainer.max_epochs=1"]
949+
input = base + [
950+
"--trainer.callbacks=ModelCheckpoint",
951+
"--trainer.callbacks.monitor=val_loss",
952+
"--trainer.callbacks=ModelCheckpoint",
953+
"--trainer.callbacks.monitor=val_acc",
954+
]
955+
callbacks = [
956+
{
957+
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
958+
"init_args": {"monitor": "val_loss"},
959+
},
960+
{
961+
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
962+
"init_args": {"monitor": "val_acc"},
963+
},
964+
]
965+
expected = base + ["--trainer.callbacks", str(callbacks)]
966+
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
967+
assert argv == expected
968+
969+
970+
def test_argv_transformation_multiple_callbacks_with_config():
971+
base = ["any.py", "--trainer.max_epochs=1"]
972+
nested_key = "trainer.callbacks"
973+
input = base + [
974+
f"--{nested_key}=ModelCheckpoint",
975+
f"--{nested_key}.monitor=val_loss",
976+
f"--{nested_key}=ModelCheckpoint",
977+
f"--{nested_key}.monitor=val_acc",
978+
f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
979+
]
980+
callbacks = [
981+
{
982+
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
983+
"init_args": {"monitor": "val_loss"},
984+
},
985+
{
986+
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
987+
"init_args": {"monitor": "val_acc"},
988+
},
989+
{"class_path": "pytorch_lightning.callbacks.Callback"},
990+
]
991+
expected = base + ["--trainer.callbacks", str(callbacks)]
992+
nested_key = "trainer.callbacks"
993+
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
994+
assert argv == expected
995+
898996

899997
@pytest.mark.parametrize(
900998
["args", "expected", "nested_key", "registry"],

0 commit comments

Comments (0)