Skip to content

Commit eac6c3f

Browse files
lantiga authored and Borda committed
Switch from tensorboard to tensorboardx in logger (#15728)
* Switch from tensorboard to tensorboardx in logger
* Warn if log_graph is set to True but tensorboard is not installed
* Fix warning message formatting
* Apply suggestions from code review
* simplify for TBX as required pkg
* docs example
* chlog
* tbx 2.2

Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>

(cherry picked from commit 9c2eb52)
1 parent eeb7166 commit eac6c3f

File tree

7 files changed

+30
-8
lines changed

7 files changed

+30
-8
lines changed

requirements/pytorch/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ torch>=1.9.*, <=1.13.0
66
tqdm>=4.57.0, <4.65.0
77
PyYAML>=5.4, <=6.0
88
fsspec[http]>2021.06.0, <2022.8.0
9-
tensorboard>=2.9.1, <2.12.0
9+
tensorboardX>=2.2, <=2.5.1 # min version is set by torch.onnx missing attribute
1010
torchmetrics>=0.7.0, <0.10.1 # needed for using fixed compare_version
1111
packaging>=17.0, <=21.3
1212
typing-extensions>=4.0.0, <=4.4.0

requirements/pytorch/extra.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@ omegaconf>=2.0.5, <2.3.0
77
hydra-core>=1.0.5, <1.3.0
88
jsonargparse[signatures]>=4.15.2, <4.16.0
99
rich>=10.14.0, !=10.15.0.a, <13.0.0
10-
protobuf<=3.20.1 # strict # an extra is updating protobuf, this pin prevents TensorBoard failure

requirements/pytorch/test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ psutil<5.9.4 # for `DeviceStatsMonitor`
1414
pandas>1.0, <1.5.2 # needed in benchmarks
1515
fastapi<0.87.0
1616
uvicorn<0.19.1
17+
18+
tensorboard>=2.9.1, <2.12.0
19+
protobuf<=3.20.1 # strict # an extra is updating protobuf, this pin prevents TensorBoard failure

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616
- Temporarily removed support for Hydra multi-run ([#15737](https://github.com/Lightning-AI/lightning/pull/15737))
1717

1818

19+
- Switched from `tensorboard` to `tensorboardX` in `TensorBoardLogger` ([#15728](https://github.com/Lightning-AI/lightning/pull/15728))
20+
21+
1922
### Fixed
2023

2124
-
@@ -46,7 +49,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4649

4750
## [1.8.0] - 2022-11-01
4851

49-
5052
### Added
5153

5254
- Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040))

src/pytorch_lightning/loggers/tensorboard.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from typing import Any, Dict, Mapping, Optional, Union
2323

2424
import numpy as np
25+
from lightning_utilities.core.imports import RequirementCache
26+
from tensorboardX import SummaryWriter
27+
from tensorboardX.summary import hparams
2528
from torch import Tensor
26-
from torch.utils.tensorboard import SummaryWriter
27-
from torch.utils.tensorboard.summary import hparams
2829

2930
import pytorch_lightning as pl
3031
from lightning_lite.utilities.cloud_io import get_filesystem
@@ -38,6 +39,8 @@
3839

3940
log = logging.getLogger(__name__)
4041

42+
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
43+
4144
if _OMEGACONF_AVAILABLE:
4245
from omegaconf import Container, OmegaConf
4346

@@ -46,7 +49,7 @@ class TensorBoardLogger(Logger):
4649
r"""
4750
Log to local file system in `TensorBoard <https://www.tensorflow.org/tensorboard>`_ format.
4851
49-
Implemented using :class:`~torch.utils.tensorboard.SummaryWriter`. Logs are saved to
52+
Implemented using :class:`~tensorboardX.SummaryWriter`. Logs are saved to
5053
``os.path.join(save_dir, name, version)``. This is the default logger in Lightning, it comes
5154
preinstalled.
5255
@@ -77,11 +80,20 @@ class TensorBoardLogger(Logger):
7780
sub_dir: Sub-directory to group TensorBoard logs. If a sub_dir argument is passed
7881
then logs are saved in ``/save_dir/name/version/sub_dir/``. Defaults to ``None`` in which
7982
logs are saved in ``/save_dir/name/version/``.
80-
\**kwargs: Additional arguments used by :class:`SummaryWriter` can be passed as keyword
83+
\**kwargs: Additional arguments used by :class:`tensorboardX.SummaryWriter` can be passed as keyword
8184
arguments in this logger. To automatically flush to disk, `max_queue` sets the size
8285
of the queue for pending logs before flushing. `flush_secs` determines how many seconds
8386
elapse before flushing.
8487
88+
Example:
89+
>>> import shutil, tempfile
90+
>>> tmp = tempfile.mkdtemp()
91+
>>> tbl = TensorBoardLogger(tmp)
92+
>>> tbl.log_hyperparams({"epochs": 5, "optimizer": "Adam"})
93+
>>> tbl.log_metrics({"acc": 0.75})
94+
>>> tbl.log_metrics({"acc": 0.9})
95+
>>> tbl.finalize("success")
96+
>>> shutil.rmtree(tmp)
8597
"""
8698
NAME_HPARAMS_FILE = "hparams.yaml"
8799
LOGGER_JOIN_CHAR = "-"
@@ -103,7 +115,10 @@ def __init__(
103115
self._name = name or ""
104116
self._version = version
105117
self._sub_dir = None if sub_dir is None else os.fspath(sub_dir)
106-
self._log_graph = log_graph
118+
if log_graph and not _TENSORBOARD_AVAILABLE:
119+
rank_zero_warn("You set `TensorBoardLogger(log_graph=True)` but `tensorboard` is not available.")
120+
self._log_graph = log_graph and _TENSORBOARD_AVAILABLE
121+
107122
self._default_hp_metric = default_hp_metric
108123
self._prefix = prefix
109124
self._fs = get_filesystem(save_dir)

tests/tests_pytorch/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def restore_env_variables():
7575
"CUDA_MODULE_LOADING", # leaked since PyTorch 1.13
7676
"KMP_INIT_AT_FORK", # leaked since PyTorch 1.13
7777
"KMP_DUPLICATE_LIB_OK", # leaked since PyTorch 1.13
78+
"CRC32C_SW_MODE", # leaked by tensorboardX
7879
}
7980
leaked_vars.difference_update(allowlist)
8081
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"

tests/tests_pytorch/loggers/test_tensorboard.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pytorch_lightning import Trainer
2525
from pytorch_lightning.demos.boring_classes import BoringModel
2626
from pytorch_lightning.loggers import TensorBoardLogger
27+
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
2728
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
2829
from tests_pytorch.helpers.runif import RunIf
2930

@@ -220,6 +221,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
220221
logger.log_graph(model, example_input_array)
221222

222223

224+
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
223225
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
224226
"""test that log graph throws warning if model.example_input_array is None."""
225227
model = BoringModel()

0 commit comments

Comments
 (0)