Skip to content

Commit a418140

Browse files
awaelchlicarmocca
andcommitted
Skip hanging spawn tests (#10838)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 319c8d3 commit a418140

File tree

7 files changed

+49
-34
lines changed

7 files changed

+49
-34
lines changed

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from torch.nn.parallel.distributed import DistributedDataParallel
2626

2727
import pytorch_lightning as pl
28-
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2928
from pytorch_lightning.overrides import LightningDistributedModule
3029
from pytorch_lightning.overrides.distributed import prepare_for_backward
3130
from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
@@ -171,17 +170,14 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
171170
return {"nprocs": self.num_processes}
172171

173172
def start_training(self, trainer: "pl.Trainer") -> None:
174-
self._clean_logger(trainer)
175173
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
176174
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
177175
trainer.optimizers = []
178176

179177
def start_evaluating(self, trainer: "pl.Trainer") -> None:
180-
self._clean_logger(trainer)
181178
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
182179

183180
def start_predicting(self, trainer: "pl.Trainer") -> None:
184-
self._clean_logger(trainer)
185181
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
186182

187183
def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
@@ -444,16 +440,3 @@ def teardown(self) -> None:
444440
self.lightning_module.cpu()
445441
# clean up memory
446442
torch.cuda.empty_cache()
447-
448-
@staticmethod
449-
def _clean_logger(trainer: "pl.Trainer") -> None:
450-
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
451-
for logger in loggers:
452-
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
453-
rank_zero_warn(
454-
"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
455-
)
456-
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
457-
# we want to make sure these are closed before we spawn our own threads.
458-
# assuming nothing else references the experiment object, python should instantly `__del__` it.
459-
logger._experiment = None

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch.utils.data import DataLoader
2525

2626
import pytorch_lightning as pl
27+
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2728
from pytorch_lightning.overrides import LightningDistributedModule
2829
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
2930
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -289,8 +290,17 @@ def start_training(self, trainer: "pl.Trainer") -> None:
289290
# todo: precision pluging is call in accelerator setup and should be moved
290291
if "XLA_USE_BF16" in os.environ:
291292
del os.environ["XLA_USE_BF16"]
293+
self._clean_logger(trainer)
292294
return super().start_training(trainer)
293295

296+
def start_evaluating(self, trainer: "pl.Trainer") -> None:
297+
self._clean_logger(trainer)
298+
return super().start_evaluating(trainer)
299+
300+
def start_predicting(self, trainer: "pl.Trainer") -> None:
301+
self._clean_logger(trainer)
302+
return super().start_predicting(trainer)
303+
294304
def training_step(self, *args, **kwargs):
295305
return self.model(*args, **kwargs)
296306

@@ -366,3 +376,13 @@ def checkpoint_io(self) -> CheckpointIO:
366376
@checkpoint_io.setter
367377
def checkpoint_io(self, plugin: CheckpointIO) -> None:
368378
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
379+
380+
@staticmethod
381+
def _clean_logger(trainer: "pl.Trainer") -> None:
382+
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
383+
for logger in loggers:
384+
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
385+
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
386+
# we want to make sure these are closed before we spawn our own threads.
387+
# assuming nothing else references the experiment object, python should instantly `__del__` it.
388+
logger._experiment = None

tests/helpers/runif.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __new__(
7171
deepspeed: bool = False,
7272
rich: bool = False,
7373
skip_49370: bool = False,
74+
skip_hanging_spawn: bool = False,
7475
**kwargs,
7576
):
7677
"""
@@ -93,6 +94,7 @@ def __new__(
9394
deepspeed: if `deepspeed` module is required to run the test
9495
rich: if `rich` module is required to run the test
9596
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
97+
skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn.
9698
kwargs: native pytest.mark.skipif keyword arguments
9799
"""
98100
conditions = []
@@ -178,6 +180,15 @@ def __new__(
178180
conditions.append(ge_3_9 and old_torch)
179181
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
180182

183+
if skip_hanging_spawn:
184+
# strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work
185+
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
186+
ge_3_8 = Version(py_version) >= Version("3.8")
187+
torch_version = get_distribution("torch").version
188+
old_torch = Version(torch_version) < Version("1.9")
189+
conditions.append(ge_3_8 and old_torch)
190+
reasons.append("Impacted by hanging DDP spawn")
191+
181192
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
182193
return pytest.mark.skipif(
183194
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs

tests/loggers/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
321321
assert pl_module.logger.experiment.something(foo="bar") is None
322322

323323

324-
@RunIf(skip_windows=True, skip_49370=True)
324+
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
325325
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
326326
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
327327
"""Test that loggers get replaced by dummy loggers on global rank > 0."""

tests/loggers/test_tensorboard.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from pytorch_lightning import Trainer
2727
from pytorch_lightning.loggers import TensorBoardLogger
28-
from pytorch_lightning.loggers.base import LoggerCollection
2928
from pytorch_lightning.utilities.imports import _compare_version
3029
from tests.helpers import BoringModel
3130

@@ -333,17 +332,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
333332
assert logger.version == 0
334333

335334
assert "Missing logger folder:" in caplog.text
336-
337-
338-
@pytest.mark.parametrize("use_list", [False, True])
339-
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
340-
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
341-
assert tensorboard_logger._experiment is None
342-
tensorboard_logger.experiment # this property access will create the experiment
343-
assert tensorboard_logger._experiment is not None
344-
logger = [tensorboard_logger] if use_list else tensorboard_logger
345-
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
346-
trainer.training_type_plugin._clean_logger(trainer)
347-
if use_list:
348-
assert isinstance(trainer.logger, LoggerCollection)
349-
assert tensorboard_logger._experiment is None

tests/plugins/test_tpu_spawn.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.utils.data import DataLoader
2121

2222
from pytorch_lightning import Trainer
23+
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2324
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
2425
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2526
from tests.helpers.boring_model import BoringModel, RandomDataset
@@ -102,3 +103,18 @@ def test_model_tpu_one_core():
102103
model = BoringModelTPU()
103104
trainer.fit(model)
104105
assert "PT_XLA_DEBUG" not in os.environ
106+
107+
108+
@RunIf(tpu=True)
109+
@pytest.mark.parametrize("use_list", [False, True])
110+
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
111+
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
112+
assert tensorboard_logger._experiment is None
113+
tensorboard_logger.experiment # this property access will create the experiment
114+
assert tensorboard_logger._experiment is not None
115+
logger = [tensorboard_logger] if use_list else tensorboard_logger
116+
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
117+
trainer.training_type_plugin._clean_logger(trainer)
118+
if use_list:
119+
assert isinstance(trainer.logger, LoggerCollection)
120+
assert tensorboard_logger._experiment is None

tests/utilities/test_all_gather_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _test_all_gather_ddp(rank, world_size):
4141
assert torch.allclose(grad2, tensor2.grad)
4242

4343

44-
@RunIf(skip_windows=True, skip_49370=True)
44+
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
4545
def test_all_gather_ddp_spawn():
4646
world_size = 3
4747
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)

0 commit comments

Comments
 (0)