Skip to content

Commit a9024ce

Browse files
authored
[CLI] Fix SaveConfigCallback with DDP spawn (#12011)
1 parent 01c31ae commit a9024ce

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

pytorch_lightning/utilities/cli.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -415,8 +415,6 @@ def __init__(
415415
self.multifile = multifile
416416

417417
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
418-
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
419-
# and we want to save before processes are spawned
420418
log_dir = trainer.log_dir # this broadcasts the directory
421419
assert log_dir is not None
422420
config_path = os.path.join(log_dir, self.config_filename)
@@ -437,18 +435,14 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
437435

438436
# save the file on rank 0
439437
if trainer.is_global_zero:
440-
# save only on rank zero to avoid race conditions on DDP.
438+
# save only on rank zero to avoid race conditions.
441439
# the `log_dir` needs to be created as we rely on the logger to do it usually
442440
# but it hasn't logged anything at this point
443441
fs.makedirs(log_dir, exist_ok=True)
444442
self.parser.save(
445443
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
446444
)
447445

448-
def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
449-
# `ArgumentParser` is un-pickleable. Drop it
450-
return self.__class__, (None, self.config, self.config_filename), {}
451-
452446

453447
class LightningCLI:
454448
"""Implementation of a configurable command line tool for pytorch-lightning."""

tests/utilities/test_cli.py

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@
5050
SaveConfigCallback,
5151
)
5252
from pytorch_lightning.utilities.exceptions import MisconfigurationException
53-
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
53+
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCHVISION_AVAILABLE
5454
from tests.helpers import BoringDataModule, BoringModel
5555
from tests.helpers.runif import RunIf
5656
from tests.helpers.utils import no_warning_call
@@ -576,21 +576,17 @@ def on_fit_start(self):
576576
raise MisconfigurationException("Error on fit start")
577577

578578

579+
@RunIf(skip_windows=True)
579580
@pytest.mark.parametrize("logger", (False, True))
580-
@pytest.mark.parametrize(
581-
"trainer_kwargs",
582-
(
583-
# dict(strategy="ddp_spawn")
584-
# dict(strategy="ddp")
585-
# the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
586-
# TODO revisit this test as it never worked with DDP or DDPSpawn
587-
dict(strategy="single_device"),
588-
pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
589-
),
590-
)
591-
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
581+
@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
582+
def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
583+
if _TORCH_GREATER_EQUAL_1_8:
584+
from torch.multiprocessing import ProcessRaisedException
585+
else:
586+
ProcessRaisedException = Exception
587+
592588
with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
593-
MisconfigurationException, match=r"Error on fit start"
589+
(MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
594590
):
595591
LightningCLI(
596592
EarlyExitTestModel,
@@ -599,7 +595,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
599595
"logger": logger,
600596
"max_steps": 1,
601597
"max_epochs": 1,
602-
**trainer_kwargs,
598+
"strategy": strategy,
599+
"accelerator": "auto",
600+
"devices": 1,
603601
},
604602
)
605603
if logger:

0 commit comments

Comments
 (0)