
Commit 15bbfac

tchaton authored and awaelchli committed
Delete TensorBoardLogger experiment before spawning the processes. (#10777)
1 parent c16269b commit 15bbfac

File tree

4 files changed: +35 -9 lines changed

    CHANGELOG.md
    pytorch_lightning/plugins/training_type/ddp_spawn.py
    pytorch_lightning/plugins/training_type/tpu_spawn.py
    tests/loggers/test_tensorboard.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))


+- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777))
+
+
 ## [1.5.2] - 2021-11-16

 ### Fixed
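
A minimal sketch of the user-visible scenario this changelog entry covers, assuming the 1.5.x Trainer and logger APIs used elsewhere in this diff (the save_dir value is illustrative): touching `logger.experiment` on the main process eagerly creates the TensorBoard `SummaryWriter`, and before this fix that writer's internal multiprocessing queue could make the `ddp_spawn` worker processes hang.

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="lightning_logs")
logger.experiment  # lazily creates the SummaryWriter on the main process

trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
# When `trainer.fit(...)` later spawns the worker processes, the plugin now warns
# and resets `logger._experiment` to None first, so each worker creates its own writer
# instead of inheriting a handle that holds a live multiprocessing queue.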

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 17 additions & 0 deletions

@@ -25,6 +25,7 @@
 from torch.nn.parallel.distributed import DistributedDataParallel

 import pytorch_lightning as pl
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
@@ -170,14 +171,17 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
         return {"nprocs": self.num_processes}

     def start_training(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []

     def start_evaluating(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

     def start_predicting(self, trainer: "pl.Trainer") -> None:
+        self._clean_logger(trainer)
         self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)

     def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]:
@@ -440,3 +444,16 @@ def teardown(self) -> None:
         self.lightning_module.cpu()
         # clean up memory
         torch.cuda.empty_cache()
+
+    @staticmethod
+    def _clean_logger(trainer: "pl.Trainer") -> None:
+        loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
+        for logger in loggers:
+            if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
+                rank_zero_warn(
+                    "When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`."
+                )
+                # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
+                # we want to make sure these are closed before we spawn our own threads.
+                # assuming nothing else references the experiment object, python should instantly `__del__` it.
+                logger._experiment = None
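
The idea behind `_clean_logger` generalizes: any handle created lazily on the main process that owns a background thread or multiprocessing queue should be dropped before `torch.multiprocessing.spawn` is called. Below is a standalone sketch of that pattern; the `_Handle` class and `spawn_cleanly` helper are illustrative stand-ins, not Lightning code.

import torch.multiprocessing as mp


class _Handle:
    """Stand-in for the TensorBoard SummaryWriter: created lazily, owns resources."""

    def __init__(self) -> None:
        self._experiment = None

    @property
    def experiment(self):
        if self._experiment is None:
            self._experiment = object()  # placeholder for a SummaryWriter-like resource
        return self._experiment


def _worker(rank: int) -> None:
    # each spawned process builds its own resources from scratch
    print(f"worker {rank} running")


def spawn_cleanly(handle: _Handle, nprocs: int = 2) -> None:
    # drop the pre-created resource on the main process before spawning,
    # mirroring `logger._experiment = None` in the plugin above
    if handle._experiment is not None:
        handle._experiment = None
    mp.spawn(_worker, nprocs=nprocs)


if __name__ == "__main__":
    h = _Handle()
    h.experiment  # eagerly create the resource, as accessing `logger.experiment` would
    spawn_cleanly(h)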

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 0 additions & 9 deletions

The TPU spawn plugin drops its bespoke `_close_logger` hook: logger cleanup now happens in the shared DDP-spawn `start_training`/`start_evaluating` path it inherits.

@@ -254,10 +254,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output

-    def _close_logger(self, trainer) -> None:
-        if trainer.logger is not None:
-            trainer.logger.finalize("success")
-
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {
             "nprocs": len(self.parallel_devices),
@@ -293,13 +289,8 @@ def start_training(self, trainer: "pl.Trainer") -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
-        self._close_logger(trainer)
         return super().start_training(trainer)

-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        self._close_logger(trainer)
-        return super().start_evaluating(trainer)
-
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

tests/loggers/test_tensorboard.py

Lines changed: 15 additions & 0 deletions

@@ -25,6 +25,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers.base import LoggerCollection
 from pytorch_lightning.utilities.imports import _compare_version
 from tests.helpers import BoringModel

@@ -332,3 +333,17 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog):
     assert logger.version == 0

     assert "Missing logger folder:" in caplog.text
+
+
+@pytest.mark.parametrize("use_list", [False, True])
+def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir):
+    tensorboard_logger = TensorBoardLogger(save_dir=tmpdir)
+    assert tensorboard_logger._experiment is None
+    tensorboard_logger.experiment  # this property access will create the experiment
+    assert tensorboard_logger._experiment is not None
+    logger = [tensorboard_logger] if use_list else tensorboard_logger
+    trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger)
+    trainer.training_type_plugin._clean_logger(trainer)
+    if use_list:
+        assert isinstance(trainer.logger, LoggerCollection)
+    assert tensorboard_logger._experiment is None
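
Note that the test never calls `trainer.fit`; it invokes `trainer.training_type_plugin._clean_logger(trainer)` directly, which keeps it cheap and deterministic while still covering both the single-logger and `LoggerCollection` branches. A small sketch of those two branches outside of a Trainer, assuming the 1.5.x `LoggerCollection` constructor that takes an iterable of loggers (the save_dir values are illustrative):

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection

tb_a = TensorBoardLogger(save_dir="logs_a")
tb_b = TensorBoardLogger(save_dir="logs_b")

for trainer_logger in (tb_a, LoggerCollection([tb_a, tb_b])):
    # the same dispatch `_clean_logger` performs on `trainer.logger`
    loggers = trainer_logger._logger_iterable if isinstance(trainer_logger, LoggerCollection) else [trainer_logger]
    for logger in loggers:
        if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
            logger._experiment = None  # the reset the test asserts on at the end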
