Skip to content

Commit a970810

Browse files
authored
Lazy import tensorboard (#15762)
1 parent 952b64b commit a970810

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

src/pytorch_lightning/loggers/tensorboard.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
import logging
2020
import os
2121
from argparse import Namespace
22-
from typing import Any, Dict, Mapping, Optional, Union
22+
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union
2323

2424
import numpy as np
2525
from lightning_utilities.core.imports import RequirementCache
26-
from tensorboardX import SummaryWriter
27-
from tensorboardX.summary import hparams
2826
from torch import Tensor
2927

3028
import pytorch_lightning as pl
@@ -40,6 +38,13 @@
4038
log = logging.getLogger(__name__)
4139

4240
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
41+
_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
42+
if TYPE_CHECKING:
43+
# assumes at least one will be installed when type checking
44+
if _TENSORBOARD_AVAILABLE:
45+
from torch.utils.tensorboard import SummaryWriter
46+
else:
47+
from tensorboardX import SummaryWriter # type: ignore[no-redef]
4348

4449
if _OMEGACONF_AVAILABLE:
4550
from omegaconf import Container, OmegaConf
@@ -109,6 +114,10 @@ def __init__(
109114
sub_dir: Optional[_PATH] = None,
110115
**kwargs: Any,
111116
):
117+
if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
118+
raise ModuleNotFoundError(
119+
"Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
120+
)
112121
super().__init__()
113122
save_dir = os.fspath(save_dir)
114123
self._save_dir = save_dir
@@ -172,7 +181,7 @@ def sub_dir(self) -> Optional[str]:
172181

173182
@property
174183
@rank_zero_experiment
175-
def experiment(self) -> SummaryWriter:
184+
def experiment(self) -> "SummaryWriter":
176185
r"""
177186
Actual tensorboard object. To use TensorBoard features in your
178187
:class:`~pytorch_lightning.core.module.LightningModule` do the following.
@@ -188,6 +197,12 @@ def experiment(self) -> SummaryWriter:
188197
assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
189198
if self.root_dir:
190199
self._fs.makedirs(self.root_dir, exist_ok=True)
200+
201+
if _TENSORBOARD_AVAILABLE:
202+
from torch.utils.tensorboard import SummaryWriter
203+
else:
204+
from tensorboardX import SummaryWriter # type: ignore[no-redef]
205+
191206
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
192207
return self._experiment
193208

@@ -224,6 +239,12 @@ def log_hyperparams(
224239

225240
if metrics:
226241
self.log_metrics(metrics, 0)
242+
243+
if _TENSORBOARD_AVAILABLE:
244+
from torch.utils.tensorboard.summary import hparams
245+
else:
246+
from tensorboardX.summary import hparams # type: ignore[no-redef]
247+
227248
exp, ssi, sei = hparams(params, metrics)
228249
writer = self.experiment._get_file_writer()
229250
writer.add_summary(exp)

tests/tests_pytorch/loggers/test_all.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import inspect
1616
import pickle
1717
from unittest import mock
18-
from unittest.mock import ANY
18+
from unittest.mock import ANY, Mock
1919

2020
import pytest
2121
import torch
@@ -31,6 +31,7 @@
3131
WandbLogger,
3232
)
3333
from pytorch_lightning.loggers.logger import DummyExperiment
34+
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
3435
from tests_pytorch.helpers.runif import RunIf
3536
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
3637
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
@@ -300,10 +301,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
300301
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
301302

302303
# TensorBoard
303-
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
304-
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
305-
logger.log_metrics({"test": 1.0}, step=0)
306-
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
304+
if _TENSORBOARD_AVAILABLE:
305+
import torch.utils.tensorboard as tb
306+
else:
307+
import tensorboardX as tb
308+
309+
monkeypatch.setattr(tb, "SummaryWriter", Mock())
310+
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
311+
logger.log_metrics({"test": 1.0}, step=0)
312+
logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
307313

308314
# WandB
309315
with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
@@ -316,17 +322,22 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
316322
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})
317323

318324

319-
def test_logger_default_name(tmpdir):
325+
def test_logger_default_name(tmpdir, monkeypatch):
320326
"""Test that the default logger name is lightning_logs."""
321327

322328
# CSV
323329
logger = CSVLogger(save_dir=tmpdir)
324330
assert logger.name == "lightning_logs"
325331

326332
# TensorBoard
327-
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
328-
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
329-
assert logger.name == "lightning_logs"
333+
if _TENSORBOARD_AVAILABLE:
334+
import torch.utils.tensorboard as tb
335+
else:
336+
import tensorboardX as tb
337+
338+
monkeypatch.setattr(tb, "SummaryWriter", Mock())
339+
logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
340+
assert logger.name == "lightning_logs"
330341

331342
# MLflow
332343
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(

tests/tests_pytorch/loggers/test_tensorboard.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from argparse import Namespace
1717
from unittest import mock
18+
from unittest.mock import Mock
1819

1920
import numpy as np
2021
import pytest
@@ -278,23 +279,28 @@ def training_step(self, *args):
278279
assert count_steps == model.indexes
279280

280281

281-
@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
282-
def test_tensorboard_finalize(summary_writer, tmpdir):
282+
def test_tensorboard_finalize(monkeypatch, tmpdir):
283283
"""Test that the SummaryWriter closes in finalize."""
284+
if _TENSORBOARD_AVAILABLE:
285+
import torch.utils.tensorboard as tb
286+
else:
287+
import tensorboardX as tb
288+
289+
monkeypatch.setattr(tb, "SummaryWriter", Mock())
284290
logger = TensorBoardLogger(save_dir=tmpdir)
285291
assert logger._experiment is None
286292
logger.finalize("any")
287293

288294
# no log calls, no experiment created -> nothing to flush
289-
summary_writer.assert_not_called()
295+
logger.experiment.assert_not_called()
290296

291297
logger = TensorBoardLogger(save_dir=tmpdir)
292298
logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment
293299
logger.finalize("any")
294300

295301
# finalize flushes to experiment directory
296-
summary_writer().flush.assert_called()
297-
summary_writer().close.assert_called()
302+
logger.experiment.flush.assert_called()
303+
logger.experiment.close.assert_called()
298304

299305

300306
def test_tensorboard_save_hparams_to_yaml_once(tmpdir):

tests/tests_pytorch/test_cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
13301330
"TensorBoardLogger",
13311331
{
13321332
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
1333-
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
1333+
},
1334+
{
1335+
"comment": "tb", # Unsupported resolving from local imports
13341336
},
13351337
)
13361338

0 commit comments

Comments
 (0)