Skip to content

Commit eb6aa7a

Browse files
authored
[CLI] Add option to enable/disable config save to preserve multiple files structure
1 parent 4bebf82 commit eb6aa7a

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
142142
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
143143

144144

145+
- Added `multifile` option to `LightningCLI` to enable/disable config save to preserve multiple files structure ([#9073](https://github.com/PyTorchLightning/pytorch-lightning/pull/9073))
146+
147+
145148
- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
146149

147150

pytorch_lightning/utilities/cli.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,13 @@ def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: Lis
341341
class SaveConfigCallback(Callback):
342342
"""Saves a LightningCLI config to the log_dir when training starts.
343343
344+
Args:
345+
parser: The parser object used to parse the configuration.
346+
config: The parsed configuration that will be saved.
347+
config_filename: Filename for the config file.
348+
overwrite: Whether to overwrite an existing config file.
349+
multifile: When input is multiple config files, saved config preserves this structure.
350+
344351
Raises:
345352
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
346353
"""
@@ -351,11 +358,13 @@ def __init__(
351358
config: Union[Namespace, Dict[str, Any]],
352359
config_filename: str,
353360
overwrite: bool = False,
361+
multifile: bool = False,
354362
) -> None:
355363
self.parser = parser
356364
self.config = config
357365
self.config_filename = config_filename
358366
self.overwrite = overwrite
367+
self.multifile = multifile
359368

360369
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
361370
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
@@ -375,7 +384,9 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
375384
# the `log_dir` needs to be created as we rely on the logger to do it usually
376385
# but it hasn't logged anything at this point
377386
get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
378-
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)
387+
self.parser.save(
388+
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
389+
)
379390

380391
def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
381392
# `ArgumentParser` is un-pickleable. Drop it
@@ -392,6 +403,7 @@ def __init__(
392403
save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
393404
save_config_filename: str = "config.yaml",
394405
save_config_overwrite: bool = False,
406+
save_config_multifile: bool = False,
395407
trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
396408
trainer_defaults: Optional[Dict[str, Any]] = None,
397409
seed_everything_default: Optional[int] = None,
@@ -424,6 +436,7 @@ def __init__(
424436
save_config_callback: A callback class to save the training config.
425437
save_config_filename: Filename for the config file.
426438
save_config_overwrite: Whether to overwrite an existing config file.
439+
save_config_multifile: When input is multiple config files, saved config preserves this structure.
427440
trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
428441
callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
429442
trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
@@ -446,6 +459,7 @@ def __init__(
446459
self.save_config_callback = save_config_callback
447460
self.save_config_filename = save_config_filename
448461
self.save_config_overwrite = save_config_overwrite
462+
self.save_config_multifile = save_config_multifile
449463
self.trainer_class = trainer_class
450464
self.trainer_defaults = trainer_defaults or {}
451465
self.seed_everything_default = seed_everything_default
@@ -627,7 +641,11 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]
627641
config["callbacks"].append(self.trainer_defaults["callbacks"])
628642
if self.save_config_callback and not config["fast_dev_run"]:
629643
config_callback = self.save_config_callback(
630-
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
644+
self.parser,
645+
self.config,
646+
self.save_config_filename,
647+
overwrite=self.save_config_overwrite,
648+
multifile=self.save_config_multifile,
631649
)
632650
config["callbacks"].append(config_callback)
633651
return self.trainer_class(**config)

0 commit comments

Comments
 (0)