
Trainer runs indefinitely if fast_dev_run=True when iterating over self.logger #10201


Closed
ritsuki1227 opened this issue Oct 27, 2021 · 4 comments · Fixed by #10232
Labels: bug (Something isn't working) · help wanted (Open to be worked on) · logger (Related to the Loggers)
@ritsuki1227 (Contributor)

🐛 Bug

If you specify fast_dev_run=True together with multiple loggers, and the model iterates over the LoggerCollection object (e.g. for logger in self.logger: ...) during training, the loop runs indefinitely. With fast_dev_run=False, it runs correctly.

To Reproduce

A slightly modified version of the BoringModel script:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}
    
    def on_train_end(self):
        # added
        print('foo')
        for logger in self.logger:
            print('bar')

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    
    tb_logger = TensorBoardLogger(save_dir='foo') # added
    mlf_logger = MLFlowLogger() # added

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        # enable_model_summary=False,
        fast_dev_run=True, # added
        logger=[tb_logger, mlf_logger] # added
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Expected behavior

The training finishes when fast_dev_run=True.

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.9
  • PyTorch Version (e.g., 1.8): 1.9.0+cu111
  • Python version: 3.9.6
  • OS (e.g., Linux): Ubuntu 18.04
  • CUDA/cuDNN version: CUDA: 11.1.1, cuDNN: 8.0.5
  • GPU models and configuration: NVIDIA Tesla T4
  • How you installed PyTorch (conda, pip, source): source
  • If compiling from source, the output of torch.__config__.show():
'PyTorch built with:\n  - GCC 7.3\n  - C++ Version: 201402\n  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n  - NNPACK is enabled\n  - CPU capability usage: AVX2\n  - CUDA Runtime 11.1\n  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n  - CuDNN 8.0.5\n  - Magma 2.5.2\n  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n'
  • Any other relevant information:


ritsuki1227 added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Oct 27, 2021
@Programmer-RD-AI (Contributor) commented Oct 28, 2021

@ritsuki1227 (Contributor, Author)

@Programmer-RD-AI
Thank you for sharing.
FYI, it can also be reproduced in my macOS environment without CUDA or GPUs.
I believe the problem is related to PyTorch Lightning's DummyLogger class rather than to native PyTorch.
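
A hedged note on the likely mechanism (the DummyStandIn class below is hypothetical, not the real DummyLogger): with fast_dev_run=True the Trainer replaces the configured loggers with a DummyLogger. If that object implements __getitem__ (so that indexing such as self.logger[0] keeps working) but does not implement __iter__, Python falls back to the legacy sequence-iteration protocol, calling __getitem__(0), __getitem__(1), ... until an IndexError is raised. A __getitem__ that always returns an object never raises, so for logger in self.logger never terminates:

# Hypothetical stand-in to illustrate the mechanism; not the real
# pytorch_lightning DummyLogger.
class DummyStandIn:
    def __getitem__(self, idx):
        # Always succeeds, so the fallback iteration protocol never sees
        # the IndexError it needs to stop.
        return self

dummy = DummyStandIn()
for i, logger in enumerate(dummy):
    print("iteration", i)
    if i >= 2:  # safety cut-off; without it this loop would run forever
        break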

@Programmer-RD-AI (Contributor)

Hi,

OK, I will try to find a solution.

@awaelchli (Contributor)

@ritsuki1227 thanks for reporting this and for the repro script (extremely valuable)! It is a fast_dev_run bug, and you are right in your suspicion about the DummyLogger. We need to add something like

    def __iter__(self):
        yield from ()

to the DummyLogger class.
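
A minimal sketch of why that one-method addition would be enough, under the same assumption as above (FixedDummyStandIn is hypothetical, not the actual patch from #10232): once __iter__ is defined, Python uses it instead of falling back to __getitem__, and an iterator that yields nothing makes the loop body run zero times, which matches iterating over an empty logger collection:

class FixedDummyStandIn:
    def __getitem__(self, idx):
        return self  # indexing, e.g. self.logger[0], keeps working

    def __iter__(self):
        # Empty iterator: `for logger in ...` now terminates immediately.
        yield from ()

for logger in FixedDummyStandIn():
    print("never reached")
print("loop finished")  # printed right away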

awaelchli changed the title from "trainer.fit runs indefinitely if fast_dev_run=True with multiple loggers" to "Trainer runs indefinitely if fast_dev_run=True when iterating over self.logger" on Oct 29, 2021
awaelchli added this to the v1.5 milestone on Oct 29, 2021
awaelchli added the logger (Related to the Loggers) label on Oct 29, 2021