check_val_every_n_epoch bug with list of dataloaders #12145

Comments
I cannot reproduce the issue. This is the code I'm using:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        print(f"{self.trainer.current_epoch=}, {dataloader_idx=}")
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=5,
        max_epochs=10,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()

Which prints as expected:

self.trainer.current_epoch=4, dataloader_idx=0
self.trainer.current_epoch=4, dataloader_idx=1
self.trainer.current_epoch=9, dataloader_idx=0
self.trainer.current_epoch=9, dataloader_idx=1

Both with 1.5.4 and current master. Can you adapt this script to reproduce it?
Thanks for this minimal example. I have been able to reproduce my bug. There seems to be a weird interaction between check_val_every_n_epoch and reload_dataloaders_every_n_epochs:
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_datasets = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
        self.curr_index = -1

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        print(f"{self.trainer.current_epoch=}, {dataloader_idx=}")
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        print("running train dataloader. Current index is:", self.curr_index, "current epoch is:", self.trainer.current_epoch)
        if self.curr_index < 1:
            self.curr_index += 1
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return self.val_datasets[:self.curr_index + 1]


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=5,
        max_epochs=20,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
        reload_dataloaders_every_n_epochs=5,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()

The only way to have … Thanks!
Looks like the issue is with how both of these parameters operate: if you set reload_dataloaders_every_n_epochs, the reload is triggered based on how many epochs have passed since the last reload, and if you set check_val_every_n_epoch, validation is triggered based on the exact epoch number. So the only difference is that one considers the epoch gap and the other considers the exact epoch. This is what is causing the issues above. IMO, both should work in a similar way, and we should choose one of the above criteria to make the fix.
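To make the mismatch concrete, here is a rough sketch of the two criteria as described in the comment above. The function names and exact conditions are illustrative assumptions, not Lightning's actual internals:

def should_reload_train_dataloader(current_epoch, last_reload_epoch, reload_every_n):
    # Gap-based: reload once `reload_every_n` epochs have passed since the last reload.
    return current_epoch - last_reload_epoch >= reload_every_n


def should_run_validation(current_epoch, check_val_every_n_epoch):
    # Exact-epoch-based: validate only when the 1-indexed epoch is a multiple of the setting.
    return (current_epoch + 1) % check_val_every_n_epoch == 0

With both settings at 5, these two conditions can become true on different epochs, so validation can run against a set of dataloaders reloaded at a different point than expected, which matches the mismatch described above.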
I'd say this is the natural behaviour for most programmers. But this was changed in #11036. Need to familiarize myself with the edge cases and tradeoffs, if any.
🐛 Bug

If you have a list of dataloaders (used in a continual learning setting) and check_val_every_n_epoch in the trainer is not equal to 1, then validation_step() won't receive the index of the dataloader currently being used.

To Reproduce

Dataloader for validation:

Validation step:
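A minimal sketch of these two pieces, adapted from the BoringModel scripts earlier in the thread (the exact snippets used by the reporter are an assumption); both are methods of the LightningModule:

def val_dataloader(self):
    # A list of validation dataloaders, e.g. one per continual-learning task.
    return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]


def validation_step(self, batch, batch_idx, dataloader_idx=None):
    # dataloader_idx should be 0 or 1 here, but stays None when the bug triggers.
    print(f"{dataloader_idx=}")
    loss = self(batch).sum()
    self.log("valid_loss", loss)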
Run this in a training loop with check_val_every_n_epoch=1 and compare to check_val_every_n_epoch=5. The latter will always have dataloader_idx=None.

Expected behavior

The dataloader_idx is always correctly provided, independently of how often validation is called.
Environment

Installed via (conda, pip, source): conda

Additional context

Thanks.
cc @carmocca @justusschock @ananthsub @ninginthecloud @rohitgr7