
Commit a013d79

Add back to TPU spawn
1 parent: 1bcb57e

File tree

  pytorch_lightning/plugins/training_type/tpu_spawn.py
  tests/plugins/test_tpu_spawn.py

2 files changed: +36 -0 lines changed


pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 20 additions & 0 deletions
@@ -24,6 +24,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -304,8 +305,17 @@ def start_training(self, trainer: "pl.Trainer") -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
+        self._clean_logger(trainer)
         return super().start_training(trainer)
 
+    def start_evaluating(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
+        return super().start_evaluating(trainer)
+
+    def start_predicting(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
+        return super().start_predicting(trainer)
+
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
@@ -381,3 +391,13 @@ def checkpoint_io(self) -> CheckpointIO:
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
+
+    @staticmethod
+    def _clean_logger(trainer: "pl.Trainer") -> None:
+        loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
+        for logger in loggers:
+            if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
+                # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
+                # we want to make sure these are closed before we spawn our own threads.
+                # assuming nothing else references the experiment object, python should instantly `__del__` it.
+                logger._experiment = None
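Note: the comments inside `_clean_logger` carry the rationale for this change: a live TensorBoard experiment (a `SummaryWriter`) owns a background writer thread and a multiprocessing queue, and spawning TPU worker processes while it is alive can hang. A minimal standalone sketch of the same pattern, outside Lightning and not part of this commit, assuming only `torch.utils.tensorboard` and `torch.multiprocessing` (log directory and process count are illustrative):

    # Sketch of the hazard _clean_logger guards against (illustrative, not Lightning code).
    import torch.multiprocessing as mp
    from torch.utils.tensorboard import SummaryWriter


    def _worker(rank: int) -> None:
        # each spawned process can create its own writer if it needs one
        print(f"worker {rank} started")


    if __name__ == "__main__":
        writer = SummaryWriter(log_dir="logs")  # creates the event-writer thread and queue
        writer.close()                          # flush and shut the writer down
        del writer                              # drop the last reference before spawning
        mp.spawn(_worker, nprocs=2)             # start worker processes afterwards

The commit applies the same idea inside the plugin: the spawn-start hooks clear `logger._experiment` so users do not have to release the writer themselves.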

tests/plugins/test_tpu_spawn.py

Lines changed: 16 additions & 0 deletions
@@ -20,6 +20,7 @@
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset
@@ -102,3 +103,18 @@ def test_model_tpu_one_core():
     model = BoringModelTPU()
     trainer.fit(model)
     assert "PT_XLA_DEBUG" not in os.environ
+
+
+@RunIf(tpu=True)
+@pytest.mark.parametrize("use_list", [False, True])
+def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
+    tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
+    assert tensorboard_logger._experiment is None
+    tensorboard_logger.experiment  # this property access will create the experiment
+    assert tensorboard_logger._experiment is not None
+    logger = [tensorboard_logger] if use_list else tensorboard_logger
+    trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
+    trainer.training_type_plugin._clean_logger(trainer)
+    if use_list:
+        assert isinstance(trainer.logger, LoggerCollection)
+    assert tensorboard_logger._experiment is None
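For context, a rough sketch of the user-facing scenario the test above mirrors; the early `experiment` access is illustrative and the Trainer arguments are taken from the test, not prescribed by this commit:

    # Illustrative only: if the logger's experiment (a SummaryWriter) is created in the
    # main process before fit(), the TPU spawn plugin now drops it before spawning workers.
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir="logs")
    _ = logger.experiment  # property access creates the SummaryWriter in the main process
    trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices="auto", logger=logger)
    # trainer.fit(model)  # start_training() now calls _clean_logger(trainer) before spawning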
