Skip to content

Commit a2dd6a0

Browse files
carmoccaawaelchli
authored andcommitted
Fix CLI race condition saving the config (#11199)
1 parent a5f7b5d commit a2dd6a0

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
-
1313

14+
- Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
1415

1516
## [1.5.7] - 2021-12-21
1617

pytorch_lightning/utilities/cli.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,30 @@ def __init__(
395395
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
396396
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
397397
# and we want to save before processes are spawned
398-
log_dir = trainer.log_dir
398+
log_dir = trainer.log_dir # this broadcasts the directory
399399
assert log_dir is not None
400400
config_path = os.path.join(log_dir, self.config_filename)
401-
if not self.overwrite and os.path.isfile(config_path):
402-
raise RuntimeError(
403-
f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
404-
" results of a previous run. You can delete the previous config file,"
405-
" set `LightningCLI(save_config_callback=None)` to disable config saving,"
406-
" or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
407-
)
401+
fs = get_filesystem(log_dir)
402+
403+
if not self.overwrite:
404+
# check if the file exists on rank 0
405+
file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
406+
# broadcast whether to fail to all ranks
407+
file_exists = trainer.strategy.broadcast(file_exists)
408+
if file_exists:
409+
raise RuntimeError(
410+
f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
411+
" results of a previous run. You can delete the previous config file,"
412+
" set `LightningCLI(save_config_callback=None)` to disable config saving,"
413+
" or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
414+
)
415+
416+
# save the file on rank 0
408417
if trainer.is_global_zero:
409418
# save only on rank zero to avoid race conditions on DDP.
410419
# the `log_dir` needs to be created as we rely on the logger to do it usually
411420
# but it hasn't logged anything at this point
412-
get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
421+
fs.makedirs(log_dir, exist_ok=True)
413422
self.parser.save(
414423
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
415424
)

0 commit comments

Comments
 (0)