check_val_every_n_epoch bug with list of dataloaders #12145

Open
TrentBrick opened this issue Feb 28, 2022 · 4 comments
Labels: bug (Something isn't working), loops (Related to the Loop API)
Milestone: 1.6.x

Comments


TrentBrick commented Feb 28, 2022

🐛 Bug

If you have a list of dataloaders (used in a continual-learning setting) and check_val_every_n_epoch in the Trainer is not equal to 1, then validation_step() won't receive the index of the dataloader currently being used.

To Reproduce

Dataloader for validation:

def val_dataloader(self):
    return [
        DataLoader(ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        for ds in self.test_datasets[:self.curr_index + 1]
    ]

Validation step:

def validation_step(self, val_batch, batch_idx, dataloader_idx=None):
    if dataloader_idx is not None:  # note: a bare `if dataloader_idx:` is also falsy for index 0
        print("data loader index provided")

Run this in a training loop with check_val_every_n_epoch=1 and compare to check_val_every_n_epoch=5. The latter will always have dataloader_idx=None.

Expected behavior

The dataloader_idx is always provided correctly, independently of how often validation is run.

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): 1.5.4
  • PyTorch Version (e.g., 1.10): 1.10.0 py3.8_cuda10.2_cudnn7.6.5_0
  • Python version (e.g., 3.9): 3.8.3
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 10.1
  • GPU models and configuration: v100
  • How you installed PyTorch (conda, pip, source): conda

Additional context

Thanks.

cc @carmocca @justusschock @ananthsub @ninginthecloud @rohitgr7

@carmocca
Contributor

I cannot reproduce the issue. This is the code I'm using:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        print(f"{self.trainer.current_epoch=}, {dataloader_idx=}")
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=5,
        max_epochs=10,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()

Which prints as expected:

self.trainer.current_epoch=4, dataloader_idx=0
self.trainer.current_epoch=4, dataloader_idx=1
self.trainer.current_epoch=9, dataloader_idx=0
self.trainer.current_epoch=9, dataloader_idx=1

Both with 1.5.4 and current master.

Can you adapt this script to reproduce it?

@carmocca carmocca added bug Something isn't working loops Related to the Loop API waiting on author Waiting on user action, correction, or update labels Feb 28, 2022

TrentBrick commented Mar 2, 2022

Thanks for this minimal example. I have been able to reproduce my bug. There is an unexpected interaction between check_val_every_n_epoch and reload_dataloaders_every_n_epochs: the error appears whenever the latter is greater than or equal to the former. With check_val_every_n_epoch=5, setting reload_dataloaders_every_n_epochs=4 (or any value less than 5) works as it should; with reload_dataloaders_every_n_epochs=5 or any larger value, I always get None values.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_datasets = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
        self.curr_index = -1

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        print(f"{self.trainer.current_epoch=}, {dataloader_idx=}")
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        print("running train dataloader. Current index is:", self.curr_index, "current epoch is:", self.trainer.current_epoch)
        if self.curr_index < 1:
            self.curr_index += 1
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return self.val_datasets[:self.curr_index+1]


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=5,
        max_epochs=20,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
        reload_dataloaders_every_n_epochs=5,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()

The only way to use reload_dataloaders_every_n_epochs>1 with the correct behaviour is to set check_val_every_n_epoch=1.
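Until this is fixed, one possible stopgap is to normalize the missing index in validation_step. The helper below is hypothetical (not part of Lightning) and only illustrates the guard pattern; it cannot recover the true index when several dataloaders are active, so it is only safe when the fallback can be inferred:

```python
def resolve_dataloader_idx(dataloader_idx, default=0):
    """Hypothetical stopgap: map a missing dataloader index to a default.

    Testing with `is None` matters here: a bare truthiness check would
    also swallow the legitimate index 0.
    """
    return default if dataloader_idx is None else dataloader_idx


# inside validation_step one would call, e.g.:
# idx = resolve_dataloader_idx(dataloader_idx)
print(resolve_dataloader_idx(None))  # falls back to 0
print(resolve_dataloader_idx(1))     # a real index passes through unchanged
```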

Thanks!

@carmocca carmocca removed the waiting on author Waiting on user action, correction, or update label Mar 3, 2022
@carmocca carmocca added this to the 1.5.x milestone Mar 3, 2022
@Borda Borda modified the milestones: 1.5.x, 1.6 Mar 21, 2022
@carmocca carmocca modified the milestones: 1.6, 1.6.x Mar 21, 2022
@rohitgr7
Contributor

Looks like the issue is with how these two parameters operate:

If you set reload_dataloaders_every_n_epochs=5, it will reload the dataloaders at the 0th, 6th, 11th, and 16th epochs, i.e., with a gap of 5 epochs.

If you set check_val_every_n_epoch=5, it will run the validation loop at the 0th, 5th, 10th, 15th, and 20th epochs.

So the difference is that one considers the epoch gap while the other considers the exact epoch. This is what's causing the issue above.

IMO, both should work in the same way, and we should pick one of the two criteria to make the fix.
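The two criteria can be sketched in plain Python. This is a toy model of the behaviour described above, not Lightning's actual implementation, and the exact reload epochs it produces may differ from Lightning's:

```python
def validation_epochs(max_epochs, every_n):
    # exact-epoch criterion: validate whenever the (zero-indexed) epoch
    # is a multiple of the interval
    return [e for e in range(max_epochs + 1) if e % every_n == 0]


def reload_epochs(max_epochs, every_n):
    # gap criterion: reload once more than `every_n` epochs have passed
    # since the previous reload (the initial load counts as a reload)
    epochs, last = [0], 0
    for e in range(1, max_epochs + 1):
        if e - last > every_n:
            epochs.append(e)
            last = e
    return epochs


print(validation_epochs(20, 5))  # [0, 5, 10, 15, 20]
print(reload_epochs(20, 5))      # [0, 6, 12, 18]
```

Because the two schedules drift apart, whether a validation epoch falls before or after a dataloader reload depends on the exact values chosen, which is consistent with the mismatch reported above.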

@carmocca
Contributor

0, 5, 10, 15 and 20th epoch.

I'd say this is the natural behaviour for most programmers.

But this was changed in #11036. We need to familiarize ourselves with the edge cases and tradeoffs, if any.
