From 8dbd103f3a16f1e4642140bed0fc5bd908620b1e Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 00:43:45 -0600 Subject: [PATCH 1/5] fix: always call WandbLogger.experiment first in _call_setup_hook to ensure tensorboard logs sync to wandb https://github.com/wandb/wandb/issues/1782#issuecomment-779161203 --- src/lightning/pytorch/trainer/call.py | 7 +++- tests/tests_pytorch/trainer/test_trainer.py | 39 ++++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 012d1a2152aa3..b5354eb2b08dd 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -21,6 +21,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch.callbacks import Checkpoint, EarlyStopping +from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal from lightning.pytorch.trainer.states import TrainerStatus @@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None: if isinstance(module, _DeviceDtypeModuleMixin): module._device = trainer.strategy.root_device + # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb: + # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203 + loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger)) + # Trigger lazy creation of experiment in loggers so loggers have their metadata available - for logger in trainer.loggers: + for logger in loggers: if hasattr(logger, "experiment"): _ = logger.experiment diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 18ae7ce77bdfc..9bf847373c89f 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -49,7 +49,7 @@ RandomIterableDataset, RandomIterableDatasetWithLen, ) -from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher @@ -1271,6 +1271,43 @@ def training_step(self, *args, **kwargs): log_metrics_mock.assert_has_calls(expected_calls) +def test_wandb_logger_experiment_called_first(tmp_path): + wandb_experiment_called = False + + def tensorboard_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + assert wandb_experiment_called + return mock.MagicMock() + + def wandb_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + wandb_experiment_called = True + return mock.MagicMock() + + with ( + mock.patch.object( + TensorBoardLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect), + ), + mock.patch.object( + WandbLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect), + ), + ): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + log_every_n_steps=1, + limit_train_batches=0, + limit_val_batches=0, + max_steps=1, + logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)], + ) + trainer.fit(model) + + class TestLightningDataModule(LightningDataModule): def __init__(self, dataloaders): super().__init__() From 35a62e7e1ee644a37ce34a0c71b1e1940e71dbe3 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 01:05:53 -0600 Subject: [PATCH 2/5] chore: update changelog for #20610 --- src/lightning/pytorch/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8bc8e45989f77..54cd0978e3604 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594)) +- always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610)) ## [2.5.0] - 2024-12-19 From 7439a0451eaaac9c2dc92983a4e13821354f73d5 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 01:34:59 -0600 Subject: [PATCH 3/5] test: Move WandbLogger test to test_wandb the trainer tests don't run with wandb installed, so we can't put it there --- tests/tests_pytorch/loggers/test_wandb.py | 39 ++++++++++++++++++++- tests/tests_pytorch/trainer/test_trainer.py | 39 +-------------------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index f3d82b0582be2..52ad03bd994b4 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -24,7 +24,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.test_cli import _xfail_python_ge_3_11_9 @@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock): assert logger._experiment is not None +def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path): + wandb_experiment_called = False + + def tensorboard_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + assert wandb_experiment_called + return mock.MagicMock() + + def wandb_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + wandb_experiment_called = True + return mock.MagicMock() + + with ( + mock.patch.object( + TensorBoardLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect), + ), + mock.patch.object( + WandbLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect), + ), + ): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + log_every_n_steps=1, + limit_train_batches=0, + limit_val_batches=0, + max_steps=1, + logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)], + ) + trainer.fit(model) + + def test_wandb_pickle(wandb_mock, tmp_path): """Verify that pickling trainer with wandb logger works. diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 9bf847373c89f..18ae7ce77bdfc 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -49,7 +49,7 @@ RandomIterableDataset, RandomIterableDatasetWithLen, ) -from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher @@ -1271,43 +1271,6 @@ def training_step(self, *args, **kwargs): log_metrics_mock.assert_has_calls(expected_calls) -def test_wandb_logger_experiment_called_first(tmp_path): - wandb_experiment_called = False - - def tensorboard_experiment_side_effect() -> mock.MagicMock: - nonlocal wandb_experiment_called - assert wandb_experiment_called - return mock.MagicMock() - - def wandb_experiment_side_effect() -> mock.MagicMock: - nonlocal wandb_experiment_called - wandb_experiment_called = True - return mock.MagicMock() - - with ( - mock.patch.object( - TensorBoardLogger, - "experiment", - new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect), - ), - mock.patch.object( - WandbLogger, - "experiment", - new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect), - ), - ): - model = BoringModel() - trainer = Trainer( - default_root_dir=tmp_path, - log_every_n_steps=1, - limit_train_batches=0, - limit_val_batches=0, - max_steps=1, - logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)], - ) - trainer.fit(model) - - class TestLightningDataModule(LightningDataModule): def __init__(self, dataloaders): super().__init__() From 48694b103193aa78a74d20cd30afe7f82924ee09 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 23:07:39 -0600 Subject: [PATCH 4/5] chore: Add blank line --- src/lightning/pytorch/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 54cd0978e3604..2f69c60696b39 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594)) + - always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610)) From 2e3dd5754447dfc858427ba6c54d025cf397d902 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 23:08:13 -0600 Subject: [PATCH 5/5] style: Capitalize first word --- src/lightning/pytorch/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 2f69c60696b39..d6b0ef04c759e 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594)) -- always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610)) +- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610)) ## [2.5.0] - 2024-12-19