
Commit bbcb977

[CLI] Shorthand notation to instantiate optimizers and lr schedulers [2/3] (#9565)
1 parent 77c719f commit bbcb977

File tree

5 files changed: +443 -51 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
 * Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
 * Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241))
+* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))


 - Fault-tolerant training:

docs/source/common/lightning_cli.rst

Lines changed: 102 additions & 37 deletions
@@ -665,69 +665,135 @@ Optimizers and learning rate schedulers
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
-single optimizer and optionally a single learning rate scheduler. In this case the model's
-:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the
-:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code
-snippet shows how to implement it:
+single optimizer and optionally a single learning rate scheduler. In this case, the model's
+:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is
+normally always the same and just adds boilerplate.

-.. testcode::
-
-    import torch
-
-
-    class MyLightningCLI(LightningCLI):
-        def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(torch.optim.Adam)
-            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
+The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when
+at most one of each is used.
+Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments:

+.. code-block:: bash

-    cli = MyLightningCLI(MyModel)
+    $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1

-With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer`
-and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and
-:code:`ExponentialLR`. Therefore, the config file would be structured like:
+A corresponding example of the config file would be:

 .. code-block:: yaml

     optimizer:
-      lr: 0.01
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.01
     lr_scheduler:
-      gamma: 0.2
+      class_path: torch.optim.lr_scheduler.ExponentialLR
+      init_args:
+        gamma: 0.1
     model:
       ...
     trainer:
       ...

-And any of these arguments could be passed directly through command line. For example:
+.. note::
+
+    This short-hand notation is only supported in the shell and not inside a configuration file. The configuration file
+    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
+
+Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
+
+.. code-block:: python
+
+    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
+
+
+    @OPTIMIZER_REGISTRY
+    class CustomAdam(torch.optim.Adam):
+        ...
+
+
+    @LR_SCHEDULER_REGISTRY
+    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+        ...
+
+
+    # register all `Optimizer` subclasses from the `torch.optim` package
+    # This is done automatically!
+    OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+
+    cli = LightningCLI(...)

 .. code-block:: bash

-    $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
+    $ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR
+
+If you need to customize the key names or link arguments together, you can choose from all available optimizers and
+learning rate schedulers by accessing the registries.
+
+.. code-block::
+
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args(
+                OPTIMIZER_REGISTRY.classes,
+                nested_key="gen_optimizer",
+                link_to="model.optimizer1_init"
+            )
+            parser.add_optimizer_args(
+                OPTIMIZER_REGISTRY.classes,
+                nested_key="gen_discriminator",
+                link_to="model.optimizer2_init"
+            )
+
+.. code-block:: bash
+
+    $ python trainer.py fit \
+        --gen_optimizer=Adam \
+        --gen_optimizer.lr=0.01 \
+        --gen_discriminator=AdamW \
+        --gen_discriminator.lr=0.0001
+
+You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
+``OPTIMIZER_REGISTRY``:
+
+.. code-block:: bash
+
+    $ python trainer.py fit \
+        --gen_optimizer.class_path=torch.optim.Adam \
+        --gen_optimizer.init_args.lr=0.01 \
+        --gen_discriminator.class_path=torch.optim.AdamW \
+        --gen_discriminator.init_args.lr=0.0001

-There is also the possibility of selecting among multiple classes by giving them as a tuple. For example:
+If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
+learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
+classes. The following code snippet shows how to implement it:

 .. testcode::

     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
+            parser.add_optimizer_args(torch.optim.Adam)
+            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

-In this case in the config the :code:`optimizer` group instead of having directly init settings, it should specify
-:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple would also be accepted.
-A corresponding example of the config file would be:
+With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the
+given classes, in this example :code:`Adam` and :code:`ExponentialLR`.
+Therefore, the config file would be structured like:

 .. code-block:: yaml

     optimizer:
-      class_path: torch.optim.Adam
-      init_args:
-        lr: 0.01
+      lr: 0.01
+    lr_scheduler:
+      gamma: 0.2
+    model:
+      ...
+    trainer:
+      ...

-And the same through command line:
+Where the arguments can be passed directly through command line without specifying the class. For example:

 .. code-block:: bash

-    $ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01
+    $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2

 The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
 example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. This would be:
@@ -763,12 +829,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. This would be:

     cli = MyLightningCLI(MyModel)

-For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with
-a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including
-:code:`class_path` and :code:`init_args` entries. The function
-:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in
-:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the
-:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
+The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
+:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
+takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
+in this case :code:`self.parameters()`, and the :code:`init_args`.
+Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.


 Notes related to reproducibility
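
The paragraph rewritten above refers to the ``optimizer_init`` dictionaries produced by ``link_to`` and to
``instantiate_class``. For reference, here is a minimal sketch (not part of this commit) of how a module might consume
those dictionaries; ``MyModel``, the attribute names, and the dummy layer are illustrative.

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer_init: dict, lr_scheduler_init: dict):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)  # dummy parameters to optimize
            # each dict carries the `class_path` and `init_args` entries filled in by the CLI
            self.optimizer_init = optimizer_init
            self.lr_scheduler_init = lr_scheduler_init

        def configure_optimizers(self):
            # `instantiate_class` imports `class_path` and calls it with the given
            # positional argument(s) plus the `init_args`
            optimizer = instantiate_class(self.parameters(), self.optimizer_init)
            scheduler = instantiate_class(optimizer, self.lr_scheduler_init)
            return {"optimizer": optimizer, "lr_scheduler": scheduler}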

pytorch_lightning/utilities/cli.py

Lines changed: 119 additions & 1 deletion
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
+import sys
 from argparse import Namespace
-from types import MethodType
+from types import MethodType, ModuleType
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from unittest import mock

+import torch
 from torch.optim import Optimizer

 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
@@ -35,9 +39,57 @@
     ArgumentParser = object


+class _Registry(dict):
+    def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None:
+        """Registers a class mapped to a name.
+
+        Args:
+            cls: the class to be mapped.
+            key: the name that identifies the provided class.
+            override: Whether to override an existing key.
+        """
+        if key is None:
+            key = cls.__name__
+        elif not isinstance(key, str):
+            raise TypeError(f"`key` must be a str, found {key}")
+
+        if key in self and not override:
+            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
+        self[key] = cls
+
+    def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
+        """This function is an utility to register all classes from a module."""
+        for _, cls in inspect.getmembers(module, predicate=inspect.isclass):
+            if issubclass(cls, base_cls) and cls != base_cls:
+                self(cls=cls, override=override)
+
+    @property
+    def names(self) -> List[str]:
+        """Returns the registered names."""
+        return list(self.keys())
+
+    @property
+    def classes(self) -> Tuple[Type, ...]:
+        """Returns the registered classes."""
+        return tuple(self.values())
+
+    def __str__(self) -> str:
+        return f"Registered objects: {self.names}"
+
+
+OPTIMIZER_REGISTRY = _Registry()
+OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+
+LR_SCHEDULER_REGISTRY = _Registry()
+LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+
+
 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

+    # use class attribute because `parse_args` is only called on the main parser
+    _choices: Dict[str, Tuple[Type, ...]] = {}
+
     def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input.

@@ -118,6 +170,7 @@ def add_optimizer_args(
         kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
         if isinstance(optimizer_class, tuple):
             self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
+            self.set_choices(nested_key, optimizer_class)
         else:
             self.add_class_arguments(optimizer_class, nested_key, **kwargs)
         self._optimizers[nested_key] = (optimizer_class, link_to)
@@ -142,10 +195,70 @@ def add_lr_scheduler_args(
         kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
         if isinstance(lr_scheduler_class, tuple):
             self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
+            self.set_choices(nested_key, lr_scheduler_class)
         else:
             self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
         self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

+    def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        argv = sys.argv
+        for k, classes in self._choices.items():
+            if not any(arg.startswith(f"--{k}") for arg in argv):
+                # the key wasn't passed - maybe defined in a config, maybe it's optional
+                continue
+            argv = self._convert_argv_issue_84(classes, k, argv)
+        self._choices.clear()
+        with mock.patch("sys.argv", argv):
+            return super().parse_args(*args, **kwargs)
+
+    def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
+        self._choices[nested_key] = classes
+
+    @staticmethod
+    def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
+        """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
+
+        This should be removed once implemented.
+        """
+        passed_args, clean_argv = {}, []
+        argv_key = f"--{nested_key}"
+        # get the argv args for this nested key
+        i = 0
+        while i < len(argv):
+            arg = argv[i]
+            if arg.startswith(argv_key):
+                if "=" in arg:
+                    key, value = arg.split("=")
+                else:
+                    key = arg
+                    i += 1
+                    value = argv[i]
+                passed_args[key] = value
+            else:
+                clean_argv.append(arg)
+            i += 1
+        # generate the associated config file
+        argv_class = passed_args.pop(argv_key, None)
+        if argv_class is None:
+            # the user passed a config as a str
+            class_path = passed_args[f"{argv_key}.class_path"]
+            init_args_key = f"{argv_key}.init_args"
+            init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
+            config = str({"class_path": class_path, "init_args": init_args})
+        elif argv_class.startswith("{"):
+            # the user passed a config as a dict
+            config = argv_class
+        else:
+            # the user passed the shorthand format
+            init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()}  # +1 to account for the period
+            for cls in classes:
+                if cls.__name__ == argv_class:
+                    config = str(_global_add_class_path(cls, init_args))
+                    break
+            else:
+                raise ValueError(f"Could not generate a config for {repr(argv_class)}")
+        return clean_argv + [argv_key, config]
+

 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
@@ -328,6 +441,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
         self.add_default_arguments_to_parser(parser)
         self.add_core_arguments_to_parser(parser)
         self.add_arguments_to_parser(parser)
+        # add default optimizer args if necessary
+        if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
+            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
+        if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
+            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
         self.link_optimizers_and_lr_schedulers(parser)

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
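
As a quick reference for the registry added above, the following is a small usage sketch (not part of this commit); it
assumes an installation that includes this change, and ``LitAdam`` is an illustrative name.

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY, OPTIMIZER_REGISTRY


    class LitAdam(torch.optim.Adam):
        ...


    # the registry instance is callable, so classes can be registered explicitly
    OPTIMIZER_REGISTRY(LitAdam)

    assert "LitAdam" in OPTIMIZER_REGISTRY.names
    assert LitAdam in OPTIMIZER_REGISTRY.classes
    # built-in schedulers were registered at import time via `register_classes`
    assert "ExponentialLR" in LR_SCHEDULER_REGISTRY.names
    print(OPTIMIZER_REGISTRY)  # e.g. "Registered objects: [..., 'LitAdam']"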

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@ torchtext>=0.7
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-jsonargparse[signatures]>=3.19.0
+jsonargparse[signatures]>=3.19.3
 gcsfs>=2021.5.0
 rich>=10.2.2
