21
21
import pytest
22
22
import torch
23
23
from omegaconf import OmegaConf
24
+ from torch .utils .data import DataLoader
24
25
25
26
from pytorch_lightning import LightningDataModule , Trainer
26
27
from pytorch_lightning .callbacks import ModelCheckpoint
28
+ from pytorch_lightning .trainer .supporters import CombinedLoader
27
29
from pytorch_lightning .utilities import AttributeDict
28
30
from pytorch_lightning .utilities .exceptions import MisconfigurationException
29
31
from pytorch_lightning .utilities .model_helpers import is_overridden
30
- from tests .helpers import BoringDataModule , BoringModel
32
+ from tests .helpers import BoringDataModule , BoringModel , RandomDataset
31
33
from tests .helpers .datamodules import ClassifDataModule
32
34
from tests .helpers .runif import RunIf
33
35
from tests .helpers .simple_models import ClassificationModel
@@ -564,13 +566,14 @@ class BoringDataModule1(LightningDataModule):
564
566
batch_size : int
565
567
dims : int = 2
566
568
567
- def __post_init__ (self ):
568
- super (). __init__ ( dims = self .dims )
569
+ def train_dataloader (self ):
570
+ return DataLoader ( torch . randn ( self . batch_size * 2 , 10 ), batch_size = self .batch_size )
569
571
570
572
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
571
573
# __repr__, __eq__, __lt__, __le__, etc.
572
574
assert BoringDataModule1 (batch_size = 64 ).dims == 2
573
575
assert BoringDataModule1 (batch_size = 32 )
576
+ assert len (BoringDataModule1 (batch_size = 32 )) == 2
574
577
assert hasattr (BoringDataModule1 , "__repr__" )
575
578
assert BoringDataModule1 (batch_size = 32 ) == BoringDataModule1 (batch_size = 32 )
576
579
@@ -581,7 +584,9 @@ class BoringDataModule2(LightningDataModule):
581
584
582
585
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
583
586
# __init__, __repr__, __eq__, __lt__, __le__, etc.
584
- assert BoringDataModule2 (batch_size = 32 )
587
+ assert BoringDataModule2 (batch_size = 32 ) is not None
588
+ assert BoringDataModule2 (batch_size = 32 ).batch_size == 32
589
+ assert len (BoringDataModule2 (batch_size = 32 )) == 0
585
590
assert hasattr (BoringDataModule2 , "__repr__" )
586
591
assert BoringDataModule2 (batch_size = 32 ).prepare_data () is None
587
592
assert BoringDataModule2 (batch_size = 32 ) == BoringDataModule2 (batch_size = 32 )
@@ -625,3 +630,69 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
625
630
trainer .model = model
626
631
trainer .datamodule = dm
627
632
trainer .data_connector .prepare_data ()
633
+
634
+
635
+ DATALOADER = DataLoader (RandomDataset (1 , 32 ))  # shared by the `test_len_*` tests below, whose expected totals count 32 batches per occurrence of this loader
636
+
637
+
638
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
@pytest.mark.parametrize(
    ("dataloader", "expected"),
    [
        # a bare dataloader
        (DATALOADER, 32),
        # a flat list of dataloaders
        ([DATALOADER, DATALOADER], 64),
        # a list of lists
        ([[DATALOADER], [DATALOADER, DATALOADER]], 96),
        # a list of dicts
        ([{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96),
        # a flat dict of dataloaders
        ({"foo": DATALOADER, "bar": DATALOADER}, 64),
        # a dict of dicts
        ({"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96),
        # a dict of lists
        ({"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96),
        # an explicit CombinedLoader wrapping a dict
        (CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64),
    ],
)
def test_len_different_types(method_name, dataloader, expected):
    """Check ``len(datamodule)`` for every supported dataloader container shape.

    A single dataloader hook (``method_name``) is attached to a plain
    ``LightningDataModule`` instance and made to return ``dataloader``.
    The reported length must then equal ``expected``; in the table above
    that is 32 for every occurrence of ``DATALOADER`` in the container.
    """

    def provide_dataloader():
        # Stand-in for the hook: always hands back the parametrized container.
        return dataloader

    datamodule = LightningDataModule()
    setattr(datamodule, method_name, provide_dataloader)

    num_batches = len(datamodule)
    assert num_batches == expected
656
+
657
+
658
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
def test_len_dataloader_no_len(method_name):
    """A dataloader whose ``__len__`` raises ``NotImplementedError`` counts as 0.

    The hook named by ``method_name`` is made to return such a dataloader;
    ``len(datamodule)`` is then expected to emit a ``UserWarning`` and be 0.
    """

    class CustomNotImplementedErrorDataloader(DataLoader):
        # DataLoader subclass that refuses to report a length.
        def __len__(self):
            raise NotImplementedError

    lengthless_loader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))

    def provide_dataloader():
        return lengthless_loader

    datamodule = LightningDataModule()
    setattr(datamodule, method_name, provide_dataloader)

    # NOTE(review): the pattern has a space between the interpolated name and the
    # closing backtick -- confirm this matches the warning text actually emitted.
    expected_warning = f"The number of batches for a dataloader in `{ method_name } ` is counted as 0"
    with pytest.warns(UserWarning, match=expected_warning):
        assert len(datamodule) == 0
669
+
670
+
671
def test_len_all_dataloader_methods_implemented():
    """``len(datamodule)`` adds up the batches of all four dataloader hooks."""

    class AllLoadersDataModule(LightningDataModule):
        # Datamodule that implements every dataloader hook, each returning the
        # same underlying dataloader wrapped in a different container type.

        def __init__(self, dataloader):
            super().__init__()
            self.dataloader = dataloader

        def train_dataloader(self):
            # dict with two entries
            loader = self.dataloader
            return {"foo": loader, "bar": loader}

        def val_dataloader(self):
            # bare dataloader
            return self.dataloader

        def test_dataloader(self):
            # single-element list
            return [self.dataloader]

        def predict_dataloader(self):
            # two-element list
            loader = self.dataloader
            return [loader, loader]

    datamodule = AllLoadersDataModule(DATALOADER)

    # The hooks expose 2 + 1 + 1 + 2 = 6 dataloaders in total, and each one
    # is expected to count as 32 batches: 6 * 32 = 192.
    expected_total = 192
    assert len(datamodule) == expected_total
693
+
694
+
695
def test_len_no_dataloader_methods_implemented():
    """A datamodule with no dataloader hooks set up warns and has length 0."""
    datamodule = LightningDataModule()

    # Pattern is matched against the emitted warning, so it is kept verbatim.
    expected_warning = "You datamodule does not have any valid dataloader"
    with pytest.warns(UserWarning, match=expected_warning):
        assert len(datamodule) == 0
0 commit comments