 import torch
 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
 
+from legacy.simple_classif_training import ClassifDataModule, ClassificationModel
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.states import RunningStage
-from tests.base.model_template import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel, RandomDataset
 
 
@@ -70,32 +70,32 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # Make sure shuffle is correct across loaders initially
     # ------------------------------------------------------
-    model = EvalModelTemplate()
-    model.train_dataloader()
+    model = ClassificationModel()
+    dm = ClassifDataModule()
 
     # original train loader which should be replaced in all methods
-    train_loader = model.train_dataloader()
+    train_loader = dm.train_dataloader()
 
     # make sure the val and tests are not shuffled
     assert isinstance(train_loader.sampler, RandomSampler)
-    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
-    assert isinstance(model.test_dataloader().sampler, SequentialSampler)
+    assert isinstance(dm.val_dataloader().sampler, SequentialSampler)
+    assert isinstance(dm.test_dataloader().sampler, SequentialSampler)
 
     # ------------------------------------------------------
     # get the training loader and batch
     # ------------------------------------------------------
     # Create a reference train dataloader without shuffling.
-    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
+    train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=False)
     (xa, ya) = next(iter(train_loader))
-    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
+    train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=True)
     full_train_samples = len(train_loader)
     num_train_samples = int(0.11 * full_train_samples)
 
     # ------------------------------------------------------
     # set VAL and Test loaders
     # ------------------------------------------------------
-    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
-    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)
+    val_loader = DataLoader(dm.val_dataloader().dataset, shuffle=False)
+    test_loader = DataLoader(dm.test_dataloader().dataset, shuffle=False)
 
     # set the model loaders
     model.train_dataloader = lambda: train_loader
@@ -165,27 +165,3 @@ def test_overfit_batch_limits(tmpdir):
             assert loader_num_batches[0] == 0
         else:
             assert loader_num_batches[0] == len(test_loader)
-
-        # ------------------------------------------------------
-        # test limit_xxx_batches as percent AND int
-        # ------------------------------------------------------
-        if split == RunningStage.VALIDATING:
-            trainer = Trainer(limit_val_batches=0.1)
-            trainer._data_connector.attach_dataloaders(model)
-            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
-            assert loader_num_batches[0] == int(0.1 * len(val_loader))
-
-            trainer = Trainer(limit_val_batches=10)
-            trainer._data_connector.attach_dataloaders(model)
-            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
-            assert loader_num_batches[0] == 10
-        else:
-            trainer = Trainer(limit_test_batches=0.1)
-            trainer._data_connector.attach_dataloaders(model)
-            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
-            assert loader_num_batches[0] == int(0.1 * len(test_loader))
-
-            trainer = Trainer(limit_test_batches=10)
-            trainer._data_connector.attach_dataloaders(model)
-            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
-            assert loader_num_batches[0] == 10