11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import inspect
14
15
import os
16
+ import sys
15
17
from argparse import Namespace
16
- from types import MethodType
18
+ from types import MethodType , ModuleType
17
19
from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
20
+ from unittest import mock
18
21
22
+ import torch
19
23
from torch .optim import Optimizer
20
24
21
25
from pytorch_lightning import Callback , LightningDataModule , LightningModule , seed_everything , Trainer
35
39
ArgumentParser = object
36
40
37
41
42
+ class _Registry (dict ):
43
+ def __call__ (self , cls : Type , key : Optional [str ] = None , override : bool = False ) -> None :
44
+ """Registers a class mapped to a name.
45
+
46
+ Args:
47
+ cls: the class to be mapped.
48
+ key: the name that identifies the provided class.
49
+ override: Whether to override an existing key.
50
+ """
51
+ if key is None :
52
+ key = cls .__name__
53
+ elif not isinstance (key , str ):
54
+ raise TypeError (f"`key` must be a str, found { key } " )
55
+
56
+ if key in self and not override :
57
+ raise MisconfigurationException (f"'{ key } ' is already present in the registry. HINT: Use `override=True`." )
58
+ self [key ] = cls
59
+
60
+ def register_classes (self , module : ModuleType , base_cls : Type , override : bool = False ) -> None :
61
+ """This function is an utility to register all classes from a module."""
62
+ for _ , cls in inspect .getmembers (module , predicate = inspect .isclass ):
63
+ if issubclass (cls , base_cls ) and cls != base_cls :
64
+ self (cls = cls , override = override )
65
+
66
+ @property
67
+ def names (self ) -> List [str ]:
68
+ """Returns the registered names."""
69
+ return list (self .keys ())
70
+
71
+ @property
72
+ def classes (self ) -> Tuple [Type , ...]:
73
+ """Returns the registered classes."""
74
+ return tuple (self .values ())
75
+
76
+ def __str__ (self ) -> str :
77
+ return f"Registered objects: { self .names } "
78
+
79
+
80
+ OPTIMIZER_REGISTRY = _Registry ()
81
+ OPTIMIZER_REGISTRY .register_classes (torch .optim , Optimizer )
82
+
83
+ LR_SCHEDULER_REGISTRY = _Registry ()
84
+ LR_SCHEDULER_REGISTRY .register_classes (torch .optim .lr_scheduler , torch .optim .lr_scheduler ._LRScheduler )
85
+
86
+
38
87
class LightningArgumentParser (ArgumentParser ):
39
88
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
40
89
90
+ # use class attribute because `parse_args` is only called on the main parser
91
+ _choices : Dict [str , Tuple [Type , ...]] = {}
92
+
41
93
def __init__ (self , * args : Any , parse_as_dict : bool = True , ** kwargs : Any ) -> None :
42
94
"""Initialize argument parser that supports configuration file input.
43
95
@@ -118,6 +170,7 @@ def add_optimizer_args(
118
170
kwargs = {"instantiate" : False , "fail_untyped" : False , "skip" : {"params" }}
119
171
if isinstance (optimizer_class , tuple ):
120
172
self .add_subclass_arguments (optimizer_class , nested_key , ** kwargs )
173
+ self .set_choices (nested_key , optimizer_class )
121
174
else :
122
175
self .add_class_arguments (optimizer_class , nested_key , ** kwargs )
123
176
self ._optimizers [nested_key ] = (optimizer_class , link_to )
@@ -142,10 +195,70 @@ def add_lr_scheduler_args(
142
195
kwargs = {"instantiate" : False , "fail_untyped" : False , "skip" : {"optimizer" }}
143
196
if isinstance (lr_scheduler_class , tuple ):
144
197
self .add_subclass_arguments (lr_scheduler_class , nested_key , ** kwargs )
198
+ self .set_choices (nested_key , lr_scheduler_class )
145
199
else :
146
200
self .add_class_arguments (lr_scheduler_class , nested_key , ** kwargs )
147
201
self ._lr_schedulers [nested_key ] = (lr_scheduler_class , link_to )
148
202
203
+ def parse_args (self , * args : Any , ** kwargs : Any ) -> Dict [str , Any ]:
204
+ argv = sys .argv
205
+ for k , classes in self ._choices .items ():
206
+ if not any (arg .startswith (f"--{ k } " ) for arg in argv ):
207
+ # the key wasn't passed - maybe defined in a config, maybe it's optional
208
+ continue
209
+ argv = self ._convert_argv_issue_84 (classes , k , argv )
210
+ self ._choices .clear ()
211
+ with mock .patch ("sys.argv" , argv ):
212
+ return super ().parse_args (* args , ** kwargs )
213
+
214
+ def set_choices (self , nested_key : str , classes : Tuple [Type , ...]) -> None :
215
+ self ._choices [nested_key ] = classes
216
+
217
+ @staticmethod
218
+ def _convert_argv_issue_84 (classes : Tuple [Type , ...], nested_key : str , argv : List [str ]) -> List [str ]:
219
+ """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
220
+
221
+ This should be removed once implemented.
222
+ """
223
+ passed_args , clean_argv = {}, []
224
+ argv_key = f"--{ nested_key } "
225
+ # get the argv args for this nested key
226
+ i = 0
227
+ while i < len (argv ):
228
+ arg = argv [i ]
229
+ if arg .startswith (argv_key ):
230
+ if "=" in arg :
231
+ key , value = arg .split ("=" )
232
+ else :
233
+ key = arg
234
+ i += 1
235
+ value = argv [i ]
236
+ passed_args [key ] = value
237
+ else :
238
+ clean_argv .append (arg )
239
+ i += 1
240
+ # generate the associated config file
241
+ argv_class = passed_args .pop (argv_key , None )
242
+ if argv_class is None :
243
+ # the user passed a config as a str
244
+ class_path = passed_args [f"{ argv_key } .class_path" ]
245
+ init_args_key = f"{ argv_key } .init_args"
246
+ init_args = {k [len (init_args_key ) + 1 :]: v for k , v in passed_args .items () if k .startswith (init_args_key )}
247
+ config = str ({"class_path" : class_path , "init_args" : init_args })
248
+ elif argv_class .startswith ("{" ):
249
+ # the user passed a config as a dict
250
+ config = argv_class
251
+ else :
252
+ # the user passed the shorthand format
253
+ init_args = {k [len (argv_key ) + 1 :]: v for k , v in passed_args .items ()} # +1 to account for the period
254
+ for cls in classes :
255
+ if cls .__name__ == argv_class :
256
+ config = str (_global_add_class_path (cls , init_args ))
257
+ break
258
+ else :
259
+ raise ValueError (f"Could not generate a config for { repr (argv_class )} " )
260
+ return clean_argv + [argv_key , config ]
261
+
149
262
150
263
class SaveConfigCallback (Callback ):
151
264
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -328,6 +441,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
328
441
self .add_default_arguments_to_parser (parser )
329
442
self .add_core_arguments_to_parser (parser )
330
443
self .add_arguments_to_parser (parser )
444
+ # add default optimizer args if necessary
445
+ if not parser ._optimizers : # already added by the user in `add_arguments_to_parser`
446
+ parser .add_optimizer_args (OPTIMIZER_REGISTRY .classes )
447
+ if not parser ._lr_schedulers : # already added by the user in `add_arguments_to_parser`
448
+ parser .add_lr_scheduler_args (LR_SCHEDULER_REGISTRY .classes )
331
449
self .link_optimizers_and_lr_schedulers (parser )
332
450
333
451
def add_arguments_to_parser (self , parser : LightningArgumentParser ) -> None :
0 commit comments