Commit 82d7d50

rohitgr7 authored and lexierule committed
Fix the num_batches value in warning (#10980)
1 parent ccfd1d8 commit 82d7d50

2 files changed: +20 -4 lines changed


pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 3 deletions
@@ -177,7 +177,7 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
         if self._requires_distributed_sampler(dataloader):
             if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                 raise MisconfigurationException(
-                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
+                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                     " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                     " distributed training. Either remove the sampler from your DataLoader or set"
                     " `replace_sampler_ddp=False` if you want to use your custom sampler."
@@ -478,7 +478,7 @@ def _reset_eval_dataloader(
         module = model or self.lightning_module or self.datamodule
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = (
+                orig_num_batches = num_batches = (
                     len(dataloader)
                     if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                     else float("inf")
@@ -504,7 +504,7 @@ def _reset_eval_dataloader(
                         min_pct = 1.0 / len(dataloader)
                         raise MisconfigurationException(
                             f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
-                            f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
+                            f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                             f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                             f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                         )
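
Why the message needed the pre-scaling count (my reading of the diff, not wording from the commit): between the second and third hunks, `num_batches` is rescaled by the float `limit_eval_batches`, so by the time the exception fires it has typically already been truncated to 0, and the old message interpolated that 0. A minimal standalone sketch with hypothetical numbers; only the variable names mirror the diff, the surrounding Trainer logic is assumed:

    # Hypothetical values; the rescaling step is assumed from context, not shown in the hunks above.
    limit_eval_batches = 0.1                # float fraction requested by the user
    orig_num_batches = num_batches = 5      # len(dataloader)

    num_batches = int(num_batches * limit_eval_batches)   # int(0.5) -> 0
    if num_batches == 0 and isinstance(limit_eval_batches, float):
        min_pct = 1.0 / orig_num_batches
        # The old message used the truncated num_batches and read "0.1*0 < 1";
        # with orig_num_batches it reads "0.1 * 5 < 1", which explains the actual limit.
        print(f"{limit_eval_batches} * {orig_num_batches} < 1. Try at least limit_val_batches={min_pct}")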

tests/trainer/test_data_loading.py

Lines changed: 17 additions & 1 deletion
@@ -20,6 +20,7 @@
 from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
@@ -279,7 +280,7 @@ class CustomSampler(Sampler):
 
     # Should raise an error if existing sampler is being replaced
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
-    with pytest.raises(MisconfigurationException, match="will be replaced  by `DistributedSampler`"):
+    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.prepare_dataloader(dataloader, shuffle=True)
 
 
@@ -348,3 +349,18 @@ def test_pre_made_batches():
     loader = DataLoader(RandomDataset(32, 10), batch_size=None)
     trainer = Trainer(fast_dev_run=1)
     trainer.predict(LoaderTestModel(), loader)
+
+
+def test_error_raised_with_float_limited_eval_batches():
+    """Test that an error is raised if there are not enough batches when passed with float value of
+    limit_eval_batches."""
+    model = BoringModel()
+    dl_size = len(model.val_dataloader())
+    limit_val_batches = 1 / (dl_size + 2)
+    trainer = Trainer(limit_val_batches=limit_val_batches)
+    trainer._data_connector.attach_data(model)
+    with pytest.raises(
+        MisconfigurationException,
+        match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
+    ):
+        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
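
A note on exercising the new test (ordinary pytest usage, not part of the commit): it can be selected directly with `pytest tests/trainer/test_data_loading.py -k test_error_raised_with_float_limited_eval_batches`.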
