@@ -493,11 +493,12 @@ def __len__(self) -> int:
493
493
num_batches = 0
494
494
not_implemented_count = 0
495
495
496
- def get_num_batches (dataloader : DataLoader ) -> None :
496
+ def get_num_batches (dataloader : DataLoader , name : str ) -> None :
497
497
nonlocal num_batches
498
498
if not has_len (dataloader ):
499
499
rank_zero_warn (
500
- "The number of batches for a dataloader is counted as 0 because it does not have `__len__` defined."
500
+ f"The number of batches for a dataloader in `{ name } ` is counted as 0 "
501
+ "because it does not have `__len__` defined."
501
502
)
502
503
else :
503
504
num_batches += len (dataloader )
@@ -506,11 +507,12 @@ def get_num_batches(dataloader: DataLoader) -> None:
506
507
dataloader_method = getattr (self , method_name )
507
508
try :
508
509
dataloader = dataloader_method ()
509
- if isinstance (dataloader , CombinedLoader ):
510
- dataloader = dataloader .loaders
511
- apply_to_collection (dataloader , DataLoader , get_num_batches )
512
510
except NotImplementedError :
513
511
not_implemented_count += 1
512
+ continue
513
+ if isinstance (dataloader , CombinedLoader ):
514
+ dataloader = dataloader .loaders
515
+ apply_to_collection (dataloader , DataLoader , get_num_batches , method_name )
514
516
515
517
if not_implemented_count == 4 :
516
518
rank_zero_warn ("You datamodule does not have any valid dataloader so `__len__` will be returned as 0." )
0 commit comments