Commit 14c552b

[bugfix] Fix dataloading for iterable datasets and limit_train_batches (#7306)
* bugfix-dataloading
* rm-logs
* Update CHANGELOG.md
* Update test_dataloaders.py
* Update test_dataloaders.py
* Update training_loop.py
* Update test_dataloaders.py
* Update CHANGELOG.md
* Update CHANGELOG.md
* Update test_dataloaders.py
* Update training_loop.py
* Update training_loop.py
* comments
* address comments
* more tests
* Update progress.py
* Update test_dataloaders.py
* Update test_dataloaders.py
* Update training_loop.py
* Update training_loop.py
* test ckpt fix?
* update again
1 parent 7636d42 commit 14c552b

File tree: 6 files changed, +241 −42 lines

CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -297,6 +297,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed


+- Fixed NaN errors in progress bars when training with iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))
+
+
+- Fixed validation being skipped for iterable datasets with no length defined ([#7306](https://github.com/PyTorchLightning/pytorch-lightning/pull/7306))
+
+
 - Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207))
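For context on the two new entries above: the failure mode is a training run whose train dataloader wraps an `IterableDataset` that defines no `__len__`, combined with `limit_train_batches`. A minimal sketch of that setup (the `StreamingDataset` name and the commented-out `fit` call are illustrative, not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning import Trainer


class StreamingDataset(IterableDataset):
    """Illustrative dataset with no __len__, so Lightning cannot infer an epoch length."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(32)


# With an unknown-length dataset, the Trainer internally treats the number of
# training batches (and val_check_batch) as float('inf'). Before this fix, that
# produced NaN progress-bar totals and skipped validation when
# limit_train_batches was set.
train_loader = DataLoader(StreamingDataset(), batch_size=4)
trainer = Trainer(max_epochs=1, limit_train_batches=10)
# trainer.fit(model, train_dataloader=train_loader)  # `model` would be any LightningModule
```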

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 3 deletions

@@ -231,9 +231,7 @@ def on_train_batch_end(
         self.save_checkpoint(trainer)

     def on_validation_end(self, trainer, pl_module) -> None:
-        """
-        checkpoints can be saved at the end of the val loop
-        """
+        """ Save a checkpoint at the end of the validation stage. """
         skip = (
             self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
             or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0

pytorch_lightning/callbacks/progress.py

Lines changed: 9 additions & 6 deletions

@@ -20,6 +20,7 @@
 """
 import importlib
 import io
+import math
 import os
 import sys

@@ -397,7 +398,7 @@ def on_train_epoch_start(self, trainer, pl_module):
         super().on_train_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
-        if total_train_batches != float('inf'):
+        if total_train_batches != float('inf') and total_val_batches != float('inf'):
             # val can be checked multiple times per epoch
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
             total_val_batches = total_val_batches * val_checks_per_epoch

@@ -407,7 +408,9 @@ def on_train_epoch_start(self, trainer, pl_module):

     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
+        total_batches = self.total_train_batches + self.total_val_batches
+        total_batches = convert_inf(total_batches)
+        if self._should_update(self.train_batch_idx, total_batches):
             self._update_bar(self.main_progress_bar)
             self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

@@ -422,7 +425,7 @@ def on_validation_start(self, trainer, pl_module):

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        if self._should_update(self.val_batch_idx, self.total_val_batches):
+        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
             self._update_bar(self.val_progress_bar)
             self._update_bar(self.main_progress_bar)

@@ -479,7 +482,7 @@ def print(
         s = sep.join(map(str, args))
         active_progress_bar.write(s, end=end, file=file, nolock=nolock)

-    def _should_update(self, current, total):
+    def _should_update(self, current, total) -> bool:
         return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

     def _update_bar(self, bar: Optional[tqdm]) -> None:

@@ -496,8 +499,8 @@ def _update_bar(self, bar: Optional[tqdm]) -> None:


 def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
-    """ The tqdm doesn't support inf values. We have to convert it to None. """
-    if x == float('inf'):
+    """ The tqdm doesn't support inf/nan values. We have to convert it to None. """
+    if x is None or math.isinf(x) or math.isnan(x):
         return None
     return x
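To make the `convert_inf` change concrete, here is a small standalone sketch of the same idea (a copy for illustration, not an import from Lightning): an unbounded total is handed to tqdm as `None`, so the bar falls back to a plain counter instead of computing NaN percentages.

```python
import math
from typing import Optional, Union

from tqdm import tqdm


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """Mirrors the patched helper: tqdm cannot render inf/nan totals, so pass None."""
    if x is None or math.isinf(x) or math.isnan(x):
        return None
    return x


total_batches = float('inf')  # what an IterableDataset with no __len__ yields as the total
bar = tqdm(total=convert_inf(total_batches))  # total=None -> open-ended counter, no NaN
for _ in range(5):
    bar.update(1)
bar.close()
```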

pytorch_lightning/trainer/training_loop.py

Lines changed: 29 additions & 14 deletions

@@ -484,9 +484,9 @@ def run_training_epoch(self):
             self.trainer.logger_connector.log_train_step_metrics(batch_output)

             # -----------------------------------------
-            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
+            # VALIDATE IF NEEDED
             # -----------------------------------------
-            should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
+            should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.validating = True
                 self.trainer.run_evaluation()

@@ -535,7 +535,7 @@ def run_training_epoch(self):
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)

-        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
         should_train_only = self.trainer.disable_validation or should_skip_eval

@@ -825,19 +825,34 @@ def should_accumulate(self):
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)

-    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
-        # decide if we should run validation
-        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
-        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-        can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
+        """ Decide if we should run validation. """
+
+        if not self.trainer.enable_validation:
+            return False
+
+        # check if this epoch is eligible to run validation
+        if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+            return False

-        should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
-                            or is_last_batch_for_infinite_dataset
-                            ) if on_epoch else (is_val_check_batch and not epoch_end_val_check)
+        # val_check_batch is inf for iterable datasets with no length defined
+        # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
+        is_val_check_batch = False
+        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+        elif self.trainer.val_check_batch != float('inf'):
+            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

-        return should_check_val and can_check_val
+        # Note: num_training_batches is also inf for iterable datasets with no length defined
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
+        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
+
+        if on_epoch:
+            return (
+                is_val_check_batch and epoch_end_val_check
+            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+        else:
+            return is_val_check_batch and not epoch_end_val_check

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
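As a rough, standalone model of the mid-epoch branch added above (assuming validation is enabled and the current epoch is eligible, and ignoring the end-of-epoch path): with an unknown-length dataset, `val_check_batch` is `inf`, so the check falls back to `limit_train_batches`.

```python
def should_run_val(batch_idx: int, limit_train_batches, val_check_batch) -> bool:
    # Simplified re-statement of the new is_val_check_batch logic, for illustration only.
    if isinstance(limit_train_batches, int) and val_check_batch == float('inf'):
        # iterable dataset with no length defined: fall back to limit_train_batches
        return (batch_idx + 1) % limit_train_batches == 0
    if val_check_batch != float('inf'):
        return (batch_idx + 1) % val_check_batch == 0
    return False


# With limit_train_batches=5 and val_check_batch=inf, validation now runs after
# batches 5, 10, 15, ... (0-indexed batch_idx 4, 9, 14, ...) instead of never.
assert [i for i in range(20) if should_run_val(i, 5, float('inf'))] == [4, 9, 14, 19]
```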

tests/helpers/boring_model.py

Lines changed: 26 additions & 1 deletion

@@ -14,7 +14,7 @@
 from typing import Optional

 import torch
-from torch.utils.data import DataLoader, Dataset, Subset
+from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

 from pytorch_lightning import LightningDataModule, LightningModule

@@ -60,6 +60,31 @@ def __len__(self):
         return self.len


+class RandomIterableDataset(IterableDataset):
+
+    def __init__(self, size: int, count: int):
+        self.count = count
+        self.size = size
+
+    def __iter__(self):
+        for _ in range(self.count):
+            yield torch.randn(self.size)
+
+
+class RandomIterableDatasetWithLen(IterableDataset):
+
+    def __init__(self, size: int, count: int):
+        self.count = count
+        self.size = size
+
+    def __iter__(self):
+        for _ in range(len(self)):
+            yield torch.randn(self.size)
+
+    def __len__(self):
+        return self.count
+
+
 class BoringModel(LightningModule):

     def __init__(self):
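These helpers back the regression tests in test_dataloaders.py (updated in this commit but not shown in this view). A hedged usage sketch, assuming the `BoringModel` from the same helpers module and the `trainer.fit(train_dataloader=...)` signature of this release line:

```python
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomIterableDataset

# RandomIterableDataset has no __len__, so the Trainer cannot infer an epoch length
# and relies on limit_train_batches to decide when to run validation.
model = BoringModel()
train_loader = DataLoader(RandomIterableDataset(size=32, count=64), batch_size=4)
val_loader = DataLoader(RandomIterableDataset(size=32, count=64), batch_size=4)

trainer = Trainer(max_epochs=1, limit_train_batches=10, limit_val_batches=5)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
```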
