Skip to content

Commit 8e1b9b3

Browse files
authored
Skip hanging spawn tests (#10838)
* Skip hanging spawn tests * Docstring fix * Add back to TPU spawn
1 parent 38ed26e commit 8e1b9b3

File tree

7 files changed

+49
-34
lines changed

7 files changed

+49
-34
lines changed

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from torch.nn.parallel.distributed import DistributedDataParallel
2626

2727
import pytorch_lightning as pl
28-
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2928
from pytorch_lightning.overrides import LightningDistributedModule
3029
from pytorch_lightning.overrides.distributed import prepare_for_backward
3130
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -149,17 +148,14 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
149148
return {"nprocs": self.num_processes}
150149

151150
def start_training(self, trainer: "pl.Trainer") -> None:
152-
self._clean_logger(trainer)
153151
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
154152
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
155153
trainer.optimizers = []
156154

157155
def start_evaluating(self, trainer: "pl.Trainer") -> None:
158-
self._clean_logger(trainer)
159156
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
160157

161158
def start_predicting(self, trainer: "pl.Trainer") -> None:
162-
self._clean_logger(trainer)
163159
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
164160

165161
def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
@@ -420,16 +416,3 @@ def teardown(self) -> None:
420416
self.lightning_module.cpu()
421417
# clean up memory
422418
torch.cuda.empty_cache()
423-
424-
@staticmethod
425-
def _clean_logger(trainer: "pl.Trainer") -> None:
426-
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
427-
for logger in loggers:
428-
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
429-
rank_zero_warn(
430-
"When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
431-
)
432-
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
433-
# we want to make sure these are closed before we spawn our own threads.
434-
# assuming nothing else references the experiment object, python should instantly `__del__` it.
435-
logger._experiment = None

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch.utils.data import DataLoader
2525

2626
import pytorch_lightning as pl
27+
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2728
from pytorch_lightning.overrides import LightningDistributedModule
2829
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
2930
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -304,8 +305,17 @@ def start_training(self, trainer: "pl.Trainer") -> None:
304305
# todo: precision pluging is call in accelerator setup and should be moved
305306
if "XLA_USE_BF16" in os.environ:
306307
del os.environ["XLA_USE_BF16"]
308+
self._clean_logger(trainer)
307309
return super().start_training(trainer)
308310

311+
def start_evaluating(self, trainer: "pl.Trainer") -> None:
312+
self._clean_logger(trainer)
313+
return super().start_evaluating(trainer)
314+
315+
def start_predicting(self, trainer: "pl.Trainer") -> None:
316+
self._clean_logger(trainer)
317+
return super().start_predicting(trainer)
318+
309319
def training_step(self, *args, **kwargs):
310320
return self.model(*args, **kwargs)
311321

@@ -381,3 +391,13 @@ def checkpoint_io(self) -> CheckpointIO:
381391
@checkpoint_io.setter
382392
def checkpoint_io(self, plugin: CheckpointIO) -> None:
383393
raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
394+
395+
@staticmethod
396+
def _clean_logger(trainer: "pl.Trainer") -> None:
397+
loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
398+
for logger in loggers:
399+
if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
400+
# the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
401+
# we want to make sure these are closed before we spawn our own threads.
402+
# assuming nothing else references the experiment object, python should instantly `__del__` it.
403+
logger._experiment = None

tests/helpers/runif.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __new__(
7171
deepspeed: bool = False,
7272
rich: bool = False,
7373
skip_49370: bool = False,
74+
skip_hanging_spawn: bool = False,
7475
omegaconf: bool = False,
7576
**kwargs,
7677
):
@@ -94,6 +95,7 @@ def __new__(
9495
deepspeed: Require that microsoft/DeepSpeed is installed.
9596
rich: Require that willmcgugan/rich is installed.
9697
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
98+
skip_hanging_spawn: Skip the test as it's impacted by hanging loggers on spawn.
9799
omegaconf: Require that omry/omegaconf is installed.
98100
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
99101
"""
@@ -180,6 +182,15 @@ def __new__(
180182
conditions.append(ge_3_9 and old_torch)
181183
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
182184

185+
if skip_hanging_spawn:
186+
# strategy=ddp_spawn, accelerator=cpu, python>=3.8, torch<1.9 does not work
187+
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
188+
ge_3_8 = Version(py_version) >= Version("3.8")
189+
torch_version = get_distribution("torch").version
190+
old_torch = Version(torch_version) < Version("1.9")
191+
conditions.append(ge_3_8 and old_torch)
192+
reasons.append("Impacted by hanging DDP spawn")
193+
183194
if omegaconf:
184195
conditions.append(not _OMEGACONF_AVAILABLE)
185196
reasons.append("omegaconf")

tests/loggers/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
329329
assert pl_module.logger.experiment.something(foo="bar") is None
330330

331331

332-
@RunIf(skip_windows=True, skip_49370=True)
332+
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
333333
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
334334
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
335335
"""Test that loggers get replaced by dummy loggers on global rank > 0."""

tests/loggers/test_tensorboard.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from pytorch_lightning import Trainer
2626
from pytorch_lightning.loggers import TensorBoardLogger
27-
from pytorch_lightning.loggers.base import LoggerCollection
2827
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
2928
from tests.helpers import BoringModel
3029
from tests.helpers.runif import RunIf
@@ -335,17 +334,3 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
335334
assert logger.version == 0
336335

337336
assert "Missing logger folder:" in caplog.text
338-
339-
340-
@pytest.mark.parametrize("use_list", [False, True])
341-
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
342-
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
343-
assert tensorboard_logger._experiment is None
344-
tensorboard_logger.experiment # this property access will create the experiment
345-
assert tensorboard_logger._experiment is not None
346-
logger = [tensorboard_logger] if use_list else tensorboard_logger
347-
trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
348-
trainer.training_type_plugin._clean_logger(trainer)
349-
if use_list:
350-
assert isinstance(trainer.logger, LoggerCollection)
351-
assert tensorboard_logger._experiment is None

tests/plugins/test_tpu_spawn.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.utils.data import DataLoader
2121

2222
from pytorch_lightning import Trainer
23+
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2324
from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
2425
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2526
from tests.helpers.boring_model import BoringModel, RandomDataset
@@ -102,3 +103,18 @@ def test_model_tpu_one_core():
102103
model = BoringModelTPU()
103104
trainer.fit(model)
104105
assert "PT_XLA_DEBUG" not in os.environ
106+
107+
108+
@RunIf(tpu=True)
109+
@pytest.mark.parametrize("use_list", [False, True])
110+
def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
111+
tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
112+
assert tensorboard_logger._experiment is None
113+
tensorboard_logger.experiment # this property access will create the experiment
114+
assert tensorboard_logger._experiment is not None
115+
logger = [tensorboard_logger] if use_list else tensorboard_logger
116+
trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
117+
trainer.training_type_plugin._clean_logger(trainer)
118+
if use_list:
119+
assert isinstance(trainer.logger, LoggerCollection)
120+
assert tensorboard_logger._experiment is None

tests/utilities/test_all_gather_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _test_all_gather_ddp(rank, world_size):
5454
assert torch.allclose(grad2, tensor2.grad)
5555

5656

57-
@RunIf(skip_windows=True, skip_49370=True)
57+
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
5858
def test_all_gather_ddp_spawn():
5959
world_size = 3
6060
torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)

0 commit comments

Comments
 (0)