|
20 | 20 | from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
|
21 | 21 |
|
22 | 22 | from pytorch_lightning import Trainer
|
| 23 | +from pytorch_lightning.trainer.states import RunningStage |
23 | 24 | from pytorch_lightning.utilities.enums import DistributedType
|
24 | 25 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
25 | 26 | from tests.helpers import BoringModel, RandomDataset
|
@@ -279,7 +280,7 @@ class CustomSampler(Sampler):
|
279 | 280 |
|
280 | 281 | # Should raise an error if existing sampler is being replaced
|
281 | 282 | dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
|
282 |
| - with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"): |
| 283 | + with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"): |
283 | 284 | trainer.prepare_dataloader(dataloader, shuffle=True)
|
284 | 285 |
|
285 | 286 |
|
@@ -348,3 +349,18 @@ def test_pre_made_batches():
|
348 | 349 | loader = DataLoader(RandomDataset(32, 10), batch_size=None)
|
349 | 350 | trainer = Trainer(fast_dev_run=1)
|
350 | 351 | trainer.predict(LoaderTestModel(), loader)
|
| 352 | + |
| 353 | + |
def test_error_raised_with_float_limited_eval_batches():
    """Verify that a ``MisconfigurationException`` is raised when a float
    ``limit_val_batches`` is too small to select even one validation batch."""
    boring_model = BoringModel()
    num_batches = len(boring_model.val_dataloader())
    # Choose a fraction guaranteed to satisfy ``limit * num_batches < 1``.
    too_small_limit = 1 / (num_batches + 2)
    trainer = Trainer(limit_val_batches=too_small_limit)
    trainer._data_connector.attach_data(boring_model)
    expected_message = fr"{too_small_limit} \* {num_batches} < 1. Please increase the `limit_val_batches`"
    with pytest.raises(MisconfigurationException, match=expected_message):
        trainer._reset_eval_dataloader(RunningStage.VALIDATING, boring_model)
0 commit comments