|
24 | 24 | from torch.utils.data import DataLoader
|
25 | 25 |
|
26 | 26 | import pytorch_lightning as pl
|
| 27 | +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger |
27 | 28 | from pytorch_lightning.overrides import LightningDistributedModule
|
28 | 29 | from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
29 | 30 | from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
@@ -304,8 +305,17 @@ def start_training(self, trainer: "pl.Trainer") -> None:
|
304 | 305 | # todo: precision plugin is called in accelerator setup and should be moved
|
305 | 306 | if "XLA_USE_BF16" in os.environ:
|
306 | 307 | del os.environ["XLA_USE_BF16"]
|
| 308 | + self._clean_logger(trainer) |
307 | 309 | return super().start_training(trainer)
|
308 | 310 |
|
| 311 | + def start_evaluating(self, trainer: "pl.Trainer") -> None: |
| 312 | + self._clean_logger(trainer) |
| 313 | + return super().start_evaluating(trainer) |
| 314 | + |
| 315 | + def start_predicting(self, trainer: "pl.Trainer") -> None: |
| 316 | + self._clean_logger(trainer) |
| 317 | + return super().start_predicting(trainer) |
| 318 | + |
309 | 319 | def training_step(self, *args, **kwargs):
|
310 | 320 | return self.model(*args, **kwargs)
|
311 | 321 |
|
@@ -381,3 +391,13 @@ def checkpoint_io(self) -> CheckpointIO:
|
381 | 391 | @checkpoint_io.setter
|
382 | 392 | def checkpoint_io(self, plugin: CheckpointIO) -> None:
|
383 | 393 | raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
|
| 394 | + |
| 395 | + @staticmethod |
| 396 | + def _clean_logger(trainer: "pl.Trainer") -> None: |
| 397 | + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] |
| 398 | + for logger in loggers: |
| 399 | + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: |
| 400 | + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make our spawned processes hang. |
| 401 | + # we want to make sure these queues are closed before we spawn our own processes. |
| 402 | + # assuming nothing else references the experiment object, Python should immediately call its `__del__`. |
| 403 | + logger._experiment = None |
0 commit comments