21
21
import pytest
22
22
import torch
23
23
from omegaconf import OmegaConf
24
- from torch .utils .data import DataLoader
25
24
26
25
from pytorch_lightning import LightningDataModule , Trainer
27
26
from pytorch_lightning .callbacks import ModelCheckpoint
28
- from pytorch_lightning .trainer .supporters import CombinedLoader
29
27
from pytorch_lightning .utilities import AttributeDict
30
28
from pytorch_lightning .utilities .exceptions import MisconfigurationException
31
29
from pytorch_lightning .utilities .model_helpers import is_overridden
32
- from tests .helpers import BoringDataModule , BoringModel , RandomDataset
30
+ from tests .helpers import BoringDataModule , BoringModel
33
31
from tests .helpers .datamodules import ClassifDataModule
34
32
from tests .helpers .runif import RunIf
35
33
from tests .helpers .simple_models import ClassificationModel
@@ -566,14 +564,13 @@ class BoringDataModule1(LightningDataModule):
566
564
batch_size : int
567
565
dims : int = 2
568
566
569
- def train_dataloader (self ):
570
- return DataLoader ( torch . randn ( self . batch_size * 2 , 10 ), batch_size = self .batch_size )
567
+ def __post_init__ (self ):
568
+ super (). __init__ ( dims = self .dims )
571
569
572
570
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
573
571
# __repr__, __eq__, __lt__, __le__, etc.
574
572
assert BoringDataModule1 (batch_size = 64 ).dims == 2
575
573
assert BoringDataModule1 (batch_size = 32 )
576
- assert len (BoringDataModule1 (batch_size = 32 )) == 2
577
574
assert hasattr (BoringDataModule1 , "__repr__" )
578
575
assert BoringDataModule1 (batch_size = 32 ) == BoringDataModule1 (batch_size = 32 )
579
576
@@ -584,9 +581,7 @@ class BoringDataModule2(LightningDataModule):
584
581
585
582
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
586
583
# __init__, __repr__, __eq__, __lt__, __le__, etc.
587
- assert BoringDataModule2 (batch_size = 32 ) is not None
588
- assert BoringDataModule2 (batch_size = 32 ).batch_size == 32
589
- assert len (BoringDataModule2 (batch_size = 32 )) == 0
584
+ assert BoringDataModule2 (batch_size = 32 )
590
585
assert hasattr (BoringDataModule2 , "__repr__" )
591
586
assert BoringDataModule2 (batch_size = 32 ).prepare_data () is None
592
587
assert BoringDataModule2 (batch_size = 32 ) == BoringDataModule2 (batch_size = 32 )
@@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
630
625
trainer .model = model
631
626
trainer .datamodule = dm
632
627
trainer ._data_connector .prepare_data ()
633
-
634
-
635
- DATALOADER = DataLoader (RandomDataset (1 , 32 ))
636
-
637
-
638
- @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
639
- @pytest .mark .parametrize (
640
- ["dataloader" , "expected" ],
641
- [
642
- [DATALOADER , 32 ],
643
- [[DATALOADER , DATALOADER ], 64 ],
644
- [[[DATALOADER ], [DATALOADER , DATALOADER ]], 96 ],
645
- [[{"foo" : DATALOADER }, {"foo" : DATALOADER , "bar" : DATALOADER }], 96 ],
646
- [{"foo" : DATALOADER , "bar" : DATALOADER }, 64 ],
647
- [{"foo" : {"foo" : DATALOADER }, "bar" : {"foo" : DATALOADER , "bar" : DATALOADER }}, 96 ],
648
- [{"foo" : [DATALOADER ], "bar" : [DATALOADER , DATALOADER ]}, 96 ],
649
- [CombinedLoader ({"foo" : DATALOADER , "bar" : DATALOADER }), 64 ],
650
- ],
651
- )
652
- def test_len_different_types (method_name , dataloader , expected ):
653
- dm = LightningDataModule ()
654
- setattr (dm , method_name , lambda : dataloader )
655
- assert len (dm ) == expected
656
-
657
-
658
- @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
659
- def test_len_dataloader_no_len (method_name ):
660
- class CustomNotImplementedErrorDataloader (DataLoader ):
661
- def __len__ (self ):
662
- raise NotImplementedError
663
-
664
- dataloader = CustomNotImplementedErrorDataloader (RandomDataset (1 , 32 ))
665
- dm = LightningDataModule ()
666
- setattr (dm , method_name , lambda : dataloader )
667
- with pytest .warns (UserWarning , match = f"The number of batches for a dataloader in `{ method_name } ` is counted as 0" ):
668
- assert len (dm ) == 0
669
-
670
-
671
- def test_len_all_dataloader_methods_implemented ():
672
- class BoringDataModule (LightningDataModule ):
673
- def __init__ (self , dataloader ):
674
- super ().__init__ ()
675
- self .dataloader = dataloader
676
-
677
- def train_dataloader (self ):
678
- return {"foo" : self .dataloader , "bar" : self .dataloader }
679
-
680
- def val_dataloader (self ):
681
- return self .dataloader
682
-
683
- def test_dataloader (self ):
684
- return [self .dataloader ]
685
-
686
- def predict_dataloader (self ):
687
- return [self .dataloader , self .dataloader ]
688
-
689
- dm = BoringDataModule (DATALOADER )
690
-
691
- # 6 dataloaders each producing 32 batches: 6 * 32 = 192
692
- assert len (dm ) == 192
693
-
694
-
695
- def test_len_no_dataloader_methods_implemented ():
696
- dm = LightningDataModule ()
697
- with pytest .warns (UserWarning , match = "You datamodule does not have any valid dataloader" ):
698
- assert len (dm ) == 0
699
-
700
- dm .train_dataloader = None
701
- dm .val_dataloader = None
702
- dm .test_dataloader = None
703
- dm .predict_dataloader = None
704
- with pytest .warns (UserWarning , match = "You datamodule does not have any valid dataloader" ):
705
- assert len (dm ) == 0
0 commit comments