@@ -341,6 +341,13 @@ def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: Lis
341
341
class SaveConfigCallback (Callback ):
342
342
"""Saves a LightningCLI config to the log_dir when training starts.
343
343
344
+ Args:
345
+ parser: The parser object used to parse the configuration.
346
+ config: The parsed configuration that will be saved.
347
+ config_filename: Filename for the config file.
348
+ overwrite: Whether to overwrite an existing config file.
349
+ multifile: When input is multiple config files, saved config preserves this structure.
350
+
344
351
Raises:
345
352
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
346
353
"""
@@ -351,11 +358,13 @@ def __init__(
351
358
config : Union [Namespace , Dict [str , Any ]],
352
359
config_filename : str ,
353
360
overwrite : bool = False ,
361
+ multifile : bool = False ,
354
362
) -> None :
355
363
self .parser = parser
356
364
self .config = config
357
365
self .config_filename = config_filename
358
366
self .overwrite = overwrite
367
+ self .multifile = multifile
359
368
360
369
def setup (self , trainer : Trainer , pl_module : LightningModule , stage : Optional [str ] = None ) -> None :
361
370
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
@@ -375,7 +384,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
375
384
# the `log_dir` needs to be created as we rely on the logger to do it usually
376
385
# but it hasn't logged anything at this point
377
386
get_filesystem (log_dir ).makedirs (log_dir , exist_ok = True )
378
- self .parser .save (self .config , config_path , skip_none = False , overwrite = self .overwrite )
387
+ self .parser .save (
388
+ self .config , config_path , skip_none = False , overwrite = self .overwrite , multifile = self .multifile
389
+ )
379
390
380
391
def __reduce__ (self ) -> Tuple [Type ["SaveConfigCallback" ], Tuple , Dict ]:
381
392
# `ArgumentParser` is un-pickleable. Drop it
@@ -392,6 +403,7 @@ def __init__(
392
403
save_config_callback : Optional [Type [SaveConfigCallback ]] = SaveConfigCallback ,
393
404
save_config_filename : str = "config.yaml" ,
394
405
save_config_overwrite : bool = False ,
406
+ save_config_multifile : bool = False ,
395
407
trainer_class : Union [Type [Trainer ], Callable [..., Trainer ]] = Trainer ,
396
408
trainer_defaults : Optional [Dict [str , Any ]] = None ,
397
409
seed_everything_default : Optional [int ] = None ,
@@ -424,6 +436,7 @@ def __init__(
424
436
save_config_callback: A callback class to save the training config.
425
437
save_config_filename: Filename for the config file.
426
438
save_config_overwrite: Whether to overwrite an existing config file.
439
+ save_config_multifile: When input is multiple config files, saved config preserves this structure.
427
440
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
428
441
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
429
442
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
@@ -446,6 +459,7 @@ def __init__(
446
459
self .save_config_callback = save_config_callback
447
460
self .save_config_filename = save_config_filename
448
461
self .save_config_overwrite = save_config_overwrite
462
+ self .save_config_multifile = save_config_multifile
449
463
self .trainer_class = trainer_class
450
464
self .trainer_defaults = trainer_defaults or {}
451
465
self .seed_everything_default = seed_everything_default
@@ -627,7 +641,11 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]
627
641
config ["callbacks" ].append (self .trainer_defaults ["callbacks" ])
628
642
if self .save_config_callback and not config ["fast_dev_run" ]:
629
643
config_callback = self .save_config_callback (
630
- self .parser , self .config , self .save_config_filename , overwrite = self .save_config_overwrite
644
+ self .parser ,
645
+ self .config ,
646
+ self .save_config_filename ,
647
+ overwrite = self .save_config_overwrite ,
648
+ multifile = self .save_config_multifile ,
631
649
)
632
650
config ["callbacks" ].append (config_callback )
633
651
return self .trainer_class (** config )
0 commit comments