Skip to content

Commit 843c647

Browse files
ringohoffman authored and Borda committed
fix: Set tensorboard's global_step as the default wandb x-axis if sync_tensorboard=True (#20611)
(cherry picked from commit a3314d4)
1 parent c8bfa05 commit 843c647

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

src/lightning/pytorch/CHANGELOG.md

+12-1
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11+
-
12+
13+
1114
### Changed
1215

13-
- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored.
16+
- Changed `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
17+
18+
19+
- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
20+
1421

1522
### Removed
1623

24+
-
25+
26+
1727
### Fixed
1828

1929
- Fixed CSVLogger logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
2030

31+
2132
- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
2233

2334

src/lightning/pytorch/loggers/wandb.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
410410
if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
411411
self._experiment, "define_metric", None
412412
):
413-
self._experiment.define_metric("trainer/global_step")
414-
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
413+
if self._wandb_init.get("sync_tensorboard"):
414+
self._experiment.define_metric("*", step_metric="global_step")
415+
else:
416+
self._experiment.define_metric("trainer/global_step")
417+
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
415418

416419
return self._experiment
417420

@@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
434437
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
435438

436439
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
437-
if step is not None:
440+
if step is not None and not self._wandb_init.get("sync_tensorboard"):
438441
self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
439442
else:
440443
self.experiment.log(metrics)

tests/tests_pytorch/loggers/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class RunType: # to make isinstance checks pass
5555
watch=Mock(),
5656
log_artifact=Mock(),
5757
use_artifact=Mock(),
58+
define_metric=Mock(),
5859
id="run_id",
5960
)
6061

tests/tests_pytorch/loggers/test_wandb.py

+18
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock):
126126
assert logger.version == wandb_mock.init().id
127127

128128

129+
def test_wandb_logger_sync_tensorboard(wandb_mock):
130+
logger = WandbLogger(sync_tensorboard=True)
131+
wandb_mock.run = None
132+
logger.experiment
133+
134+
# test that tensorboard's global_step is set as the default x-axis if sync_tensorboard=True
135+
wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step")
136+
137+
138+
def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
139+
logger = WandbLogger(sync_tensorboard=True)
140+
metrics = {"loss": 1e-3, "accuracy": 0.99}
141+
logger.log_metrics(metrics)
142+
143+
# test that trainer/global_step is not added to the logged metrics if sync_tensorboard=True
144+
wandb_mock.run.log.assert_called_once_with(metrics)
145+
146+
129147
def test_wandb_logger_init_before_spawn(wandb_mock):
130148
logger = WandbLogger()
131149
assert logger._experiment is None

0 commit comments

Comments (0)