fix: always call WandbLogger.experiment first in _call_setup_hook to ensure tensorboard logs can sync to wandb #20610

Merged
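
For context on why ordering matters: `wandb`'s TensorBoard syncing works by monkeypatching TensorBoard's writer classes when `wandb.init()` runs, so only writers constructed after that call are mirrored to the wandb run. Below is a minimal sketch of the constraint outside Lightning; the project name and log directory are illustrative, and it assumes `wandb` and `torch` are installed with a configured wandb login:

```python
import wandb
from torch.utils.tensorboard import SummaryWriter

# wandb.init must run before any SummaryWriter is constructed:
# sync_tensorboard=True patches tensorboard's writers at init time, and only
# writers created afterwards have their events mirrored to the wandb run.
run = wandb.init(project="demo", sync_tensorboard=True)  # illustrative project name

writer = SummaryWriter(log_dir="runs/demo")  # created after init -> synced
writer.add_scalar("train/loss", 0.5, global_step=0)
writer.close()
run.finish()
```

This PR applies the same principle inside the Trainer: `WandbLogger.experiment` (which calls `wandb.init`) is triggered before any other logger can create a TensorBoard writer.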
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
```diff
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
+- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
+
 
 ## [2.5.0] - 2024-12-19
 
```
7 changes: 6 additions & 1 deletion src/lightning/pytorch/trainer/call.py
```diff
@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device
 
+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
             _ = logger.experiment
 
```
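The one-line reordering in `_call_setup_hook` relies on two Python facts: a boolean sort key maps `WandbLogger` instances to `False` (0) and everything else to `True` (1), and `sorted` is stable, so wandb loggers move to the front while all other loggers keep their original relative order. A minimal sketch, with stub classes standing in for the real Lightning loggers:

```python
# Stubs standing in for the real Lightning logger classes.
class WandbLogger: ...
class TensorBoardLogger: ...
class CSVLogger: ...

loggers = [TensorBoardLogger(), WandbLogger(), CSVLogger()]

# Key is False (0) for WandbLogger and True (1) otherwise; the stable sort
# puts wandb loggers first and leaves the rest in their original order.
ordered = sorted(loggers, key=lambda logger: not isinstance(logger, WandbLogger))

assert isinstance(ordered[0], WandbLogger)
assert [type(x).__name__ for x in ordered[1:]] == ["TensorBoardLogger", "CSVLogger"]
```

Note that `trainer.loggers` itself is left untouched; only the iteration order used for triggering `experiment` changes.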
39 changes: 38 additions & 1 deletion tests/tests_pytorch/loggers/test_wandb.py
```diff
@@ -24,7 +24,7 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.cli import LightningCLI
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
@@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
     assert logger._experiment is not None
 
 
+def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
+    wandb_experiment_called = False
+
+    def tensorboard_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        assert wandb_experiment_called
+        return mock.MagicMock()
+
+    def wandb_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        wandb_experiment_called = True
+        return mock.MagicMock()
+
+    with (
+        mock.patch.object(
+            TensorBoardLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
+        ),
+        mock.patch.object(
+            WandbLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
+        ),
+    ):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            log_every_n_steps=1,
+            limit_train_batches=0,
+            limit_val_batches=0,
+            max_steps=1,
+            logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
+        )
+        trainer.fit(model)
+
+
 def test_wandb_pickle(wandb_mock, tmp_path):
     """Verify that pickling trainer with wandb logger works.
 
```
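For reference, a sketch of the user-facing setup this fix targets. It assumes `wandb` is installed and logged in, and that `WandbLogger` forwards extra keyword arguments such as `sync_tensorboard` to `wandb.init`: with both loggers attached, the event files written by `TensorBoardLogger` are mirrored to the wandb run because `WandbLogger.experiment` is now created first, regardless of the order the loggers are listed in.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# sync_tensorboard=True is forwarded to wandb.init. With this fix the wandb
# run exists before TensorBoardLogger constructs its SummaryWriter, so the
# tensorboard events are actually mirrored to wandb.
trainer = Trainer(
    max_epochs=1,
    logger=[
        TensorBoardLogger("logs/"),
        WandbLogger(save_dir="logs/", sync_tensorboard=True),
    ],
)
trainer.fit(BoringModel())
```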