 # limitations under the License.
 import pytest
 import torch
-from torch.utils.data.sampler import Sampler, SequentialSampler
+from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler

 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
+from tests.base.model_template import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel, RandomDataset


@@ -62,3 +64,128 @@ def train_dataloader(self):
     trainer.fit(model)

     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
+
+
+def test_overfit_batch_limits(tmpdir):
+    # ------------------------------------------------------
+    # Make sure shuffle is set correctly across loaders initially
+    # ------------------------------------------------------
+    model = EvalModelTemplate()
+    model.train_dataloader()
+
+    # original train loader, which should be replaced in all methods
+    train_loader = model.train_dataloader()
+
+    # the train loader is shuffled; the val and test loaders are not
+    assert isinstance(train_loader.sampler, RandomSampler)
+    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
+    assert isinstance(model.test_dataloader().sampler, SequentialSampler)
+
+    # ------------------------------------------------------
+    # get the training loader and batch
+    # ------------------------------------------------------
+    # create a reference train dataloader without shuffling
+    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
+    (xa, ya) = next(iter(train_loader))
+    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
+    full_train_samples = len(train_loader)
+    num_train_samples = int(0.11 * full_train_samples)
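+    # expected number of training batches once overfit_batches=0.11 is applied below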
+
+    # ------------------------------------------------------
+    # set the val and test loaders
+    # ------------------------------------------------------
+    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
+    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)
+
+    # set the model loaders
+    model.train_dataloader = lambda: train_loader
+    model.val_dataloader = lambda: val_loader
+    model.test_dataloader = lambda: test_loader
+
+    # ------------------------------------------------------
+    # test train loader applies correct limits
+    # ------------------------------------------------------
+    trainer = Trainer(overfit_batches=4)
+    model.trainer = trainer
+    trainer._data_connector.attach_dataloaders(model=model)
+    trainer.reset_train_dataloader(model)
+    assert trainer.num_training_batches == 4
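+    # an integer overfit_batches caps training at exactly that many batches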
+
+    # make sure the loaders are the same
+    (xb, yb) = next(iter(trainer.train_dataloader))
+    assert torch.eq(xa, xb).all()
+    assert torch.eq(ya, yb).all()
+
+    trainer = Trainer(overfit_batches=0.11)
+    model.trainer = trainer
+    trainer._data_connector.attach_dataloaders(model=model)
+    trainer.reset_train_dataloader(model)
+    # the dataloader should have been overwritten with one using a SequentialSampler
+    assert trainer.train_dataloader is not train_loader
+    assert trainer.num_training_batches == num_train_samples
+
+    # make sure the loaders are the same
+    (xb, yb) = next(iter(trainer.train_dataloader))
+    assert torch.eq(xa, xb).all()
+    assert torch.eq(ya, yb).all()
+
+    # ------------------------------------------------------
+    # run tests for both val and test
+    # ------------------------------------------------------
+    for split in (RunningStage.VALIDATING, RunningStage.TESTING):
+
+        # ------------------------------------------------------
+        # test overfit_batches as percent
+        # ------------------------------------------------------
+        trainer = Trainer(overfit_batches=0.11)
+        trainer._data_connector.attach_dataloaders(model)
+        loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
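+        # overfitting disables validation (0 batches) but leaves the test loader at its full length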
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)
+
+        # ------------------------------------------------------
+        # test overfit_batches as int
+        # ------------------------------------------------------
+        trainer = Trainer(overfit_batches=1)
+        trainer._data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)
+            # make sure we turned off shuffle for the user
+            assert isinstance(dataloaders[0].sampler, SequentialSampler)
+
+        trainer = Trainer(overfit_batches=5)
+        trainer._data_connector.attach_dataloaders(model)
+        loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
+        if split == RunningStage.VALIDATING:
+            assert loader_num_batches[0] == 0
+        else:
+            assert loader_num_batches[0] == len(test_loader)
+
+        # ------------------------------------------------------
+        # test limit_xxx_batches as percent AND int
+        # ------------------------------------------------------
+        if split == RunningStage.VALIDATING:
+            trainer = Trainer(limit_val_batches=0.1)
+            trainer._data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+            assert loader_num_batches[0] == int(0.1 * len(val_loader))
+
+            trainer = Trainer(limit_val_batches=10)
+            trainer._data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+            assert loader_num_batches[0] == 10
+        else:
+            trainer = Trainer(limit_test_batches=0.1)
+            trainer._data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+            assert loader_num_batches[0] == int(0.1 * len(test_loader))
+
+            trainer = Trainer(limit_test_batches=10)
+            trainer._data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
+            assert loader_num_batches[0] == 10