Skip to content

Commit 1f581eb

Browse files
committed
Add dataloader name in warning and wrap only dataloader_method in try
1 parent d961df5 commit 1f581eb

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

pytorch_lightning/core/datamodule.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,11 +493,12 @@ def __len__(self) -> int:
493493
num_batches = 0
494494
not_implemented_count = 0
495495

496-
def get_num_batches(dataloader: DataLoader) -> None:
496+
def get_num_batches(dataloader: DataLoader, name: str) -> None:
497497
nonlocal num_batches
498498
if not has_len(dataloader):
499499
rank_zero_warn(
500-
"The number of batches for a dataloader is counted as 0 because it does not have `__len__` defined."
500+
f"The number of batches for a dataloader in `{name}` is counted as 0 "
501+
"because it does not have `__len__` defined."
501502
)
502503
else:
503504
num_batches += len(dataloader)
@@ -506,11 +507,12 @@ def get_num_batches(dataloader: DataLoader) -> None:
506507
dataloader_method = getattr(self, method_name)
507508
try:
508509
dataloader = dataloader_method()
509-
if isinstance(dataloader, CombinedLoader):
510-
dataloader = dataloader.loaders
511-
apply_to_collection(dataloader, DataLoader, get_num_batches)
512510
except NotImplementedError:
513511
not_implemented_count += 1
512+
continue
513+
if isinstance(dataloader, CombinedLoader):
514+
dataloader = dataloader.loaders
515+
apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
514516

515517
if not_implemented_count == 4:
516518
rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.")

tests/core/test_datamodules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def __len__(self):
664664
dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
665665
dm = LightningDataModule()
666666
setattr(dm, method_name, lambda: dataloader)
667-
with pytest.warns(UserWarning, match="The number of batches for a dataloader is counted as 0"):
667+
with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
668668
assert len(dm) == 0
669669

670670

0 commit comments

Comments
 (0)