ModelCheckpoint does not infer save_dir, name, or version correctly when using multiple loggers #11682

Closed

sshkhr opened this issue Jan 31, 2022 · 4 comments

Labels: bug (Something isn't working), callback: model checkpoint, logger (Related to the Loggers)

sshkhr commented Jan 31, 2022

🐛 Bug

When using multiple loggers with the Trainer, the ModelCheckpoint callback simply concatenates the experiment names and version names of both loggers to generate its dirpath. Furthermore, the checkpoints get saved in the current working directory (default_root_dir, which is os.getcwd() in the Trainer) instead of the save_dir specified in the loggers.

Logically, even when using multiple loggers with a Trainer we are still logging the same experiment. If the save_dir inferred from the multiple loggers is consistent, then the checkpoints should be saved at that location instead of in the default_root_dir.

To Reproduce

In the BoringModel code below, I added two loggers (CSVLogger and TensorBoardLogger), both with the same save directory ('logs') and the same version (version 0 in the example, inferred by default), and a ModelCheckpoint callback with no dirpath specified.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

num_samples = 10000

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()

    # Both loggers share the same save_dir ('logs'); name and version are inferred.
    csv_logger = CSVLogger('logs')
    tb_logger = TensorBoardLogger('logs')
    progress_bar = TQDMProgressBar(refresh_rate=5)
    # No dirpath is given, so ModelCheckpoint must infer it from the loggers.
    model_checkpoint = ModelCheckpoint(every_n_train_steps=1, save_top_k=-1)

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        callbacks=[model_checkpoint, progress_bar],
        logger=[tb_logger, csv_logger],
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=5,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

Also reproduced in a Colab notebook:

https://colab.research.google.com/drive/1zk77y5A9xkuAPN0Oip8pIkO_KgJvgx0H?usp=sharing

Expected behavior

The expected behaviour (which is what happens when using a single logger) is that the checkpoints folder gets created at os.path.join(save_dir, name, version) as specified by the logger (see the sketch after the tree below).

/content/
└── logs
    └── default
        └── version_0
            ├── checkpoints
            │   ├── epoch=0-step=0.ckpt
            │   ├── epoch=1-step=1.ckpt
            │   ├── epoch=2-step=2.ckpt
            │   ├── epoch=3-step=3.ckpt
            │   └── epoch=4-step=4.ckpt
            ├── events.out.tfevents.1643669231.090a64362fd5.305.0
            ├── events.out.tfevents.1643669232.090a64362fd5.305.1
            ├── hparams.yaml
            └── metrics.csv
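
For reference, the single-logger resolution is roughly equivalent to the sketch below. The helper name resolve_ckpt_dir is hypothetical; the real logic lives in a private method of ModelCheckpoint and may differ between releases.

import os

from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase

# Hypothetical helper approximating how ModelCheckpoint picks its dirpath when
# no explicit `dirpath` is passed (a sketch only, not the library's own code).
def resolve_ckpt_dir(logger: LightningLoggerBase, default_root_dir: str) -> str:
    save_dir = logger.save_dir or default_root_dir
    version = logger.version if isinstance(logger.version, str) else f"version_{logger.version}"
    return os.path.join(save_dir, str(logger.name), version, "checkpoints")

print(resolve_ckpt_dir(CSVLogger('logs'), os.getcwd()))
# -> logs/default/version_0/checkpoints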

However, if there is more than one logger, the names and versions from all loggers are concatenated to generate the ModelCheckpoint dirpath (e.g. default_default instead of default, and 0_0 instead of 0), and the checkpoints get saved in the current working directory (default_root_dir, which is os.getcwd() in the Trainer) instead of the save_dir specified in the loggers (a sketch of the resulting path follows the tree below).

/content/
├── default_default
│   └── 0_0
│       └── checkpoints
│           ├── epoch=0-step=0.ckpt
│           ├── epoch=1-step=1.ckpt
│           ├── epoch=2-step=2.ckpt
│           ├── epoch=3-step=3.ckpt
│           └── epoch=4-step=4.ckpt
└── logs
    └── default
        └── version_0
            ├── events.out.tfevents.1643669231.090a64362fd5.305.0
            ├── events.out.tfevents.1643669232.090a64362fd5.305.1
            ├── hparams.yaml
            └── metrics.csv
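
Concretely, the directory that gets created matches what you would compute by hand (a sketch of the observed behaviour, not the library's code):

import os

# With two loggers, the aggregated name and version are the sub-loggers'
# values joined with '_', and no single save_dir is reported, so the
# checkpoint path falls back to the Trainer's default_root_dir (os.getcwd()).
name = "_".join(["default", "default"])   # -> 'default_default'
version = "_".join(["0", "0"])            # -> '0_0'
ckpt_dir = os.path.join(os.getcwd(), name, version, "checkpoints")
print(ckpt_dir)  # e.g. /content/default_default/0_0/checkpoints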

Environment

  • CUDA:
    • GPU:
      • Tesla K80
    • available: True
    • version: 11.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.10.0+cu111
    • pytorch-lightning: 1.5.9
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.12
    • version: #1 SMP Tue Dec 7 09:58:10 PST 2021

Additional context

cc @awaelchli @edward-io @Borda @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy @carmocca @ninginthecloud @jjenniferdai

sshkhr added the bug label on Jan 31, 2022
carmocca added the callback: model checkpoint and logger labels on Feb 1, 2022

daniellepintz (Contributor) commented Feb 1, 2022

For name and version, this is because of how LoggerCollection is implemented: https://github.com/PyTorchLightning/pytorch-lightning/blob/86b177ebe5427725b35fde1a8808a7b59b8a277a/pytorch_lightning/loggers/base.py#L454-L463

We are in the process of deprecating LoggerCollection in #11232, and we will make sure to solve this as part of that work. cc @akashkw
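
For reference, the aggregation referenced above behaves roughly like the sketch below, assuming the 1.5.x behaviour of joining every sub-logger's value with '_'. NaiveLoggerCollection is a made-up name for illustration; the actual class is at the linked lines.

from typing import List

from pytorch_lightning.loggers import LightningLoggerBase

# Minimal sketch of the problematic aggregation, not the actual LoggerCollection.
class NaiveLoggerCollection:
    def __init__(self, loggers: List[LightningLoggerBase]):
        self._loggers = list(loggers)

    @property
    def name(self) -> str:
        # 'default' + 'default' -> 'default_default'
        return "_".join(str(logger.name) for logger in self._loggers)

    @property
    def version(self) -> str:
        # 0 + 0 -> '0_0'
        return "_".join(str(logger.version) for logger in self._loggers)

    @property
    def save_dir(self):
        # No single directory can represent all sub-loggers, so ModelCheckpoint
        # falls back to the Trainer's default_root_dir.
        return None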

daniellepintz (Contributor) commented

@sshkhr Also this was fixed in #10976, if you are able to use the version of PTL on master this should solve your issue.

carmocca added this to the 1.6 milestone on Feb 1, 2022

carmocca (Contributor) commented Feb 1, 2022

Oh great. I'll close this then

carmocca closed this as completed on Feb 1, 2022

sshkhr (Author) commented Feb 1, 2022

Appreciate the quick response and resolution @daniellepintz @carmocca !
