Skip to content

Commit ce3e632

Browse files
peterdudfieldpre-commit-ci[bot]carmocca
authored
Fix failure when DataLoader(batch_size=None) is passed (#10345)
* add test, + add change to data loading batch sample method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor and CHANGELOG Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent a255dcb commit ce3e632

File tree

3 files changed

+25
-17
lines changed

3 files changed

+25
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7777
- Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))
7878

7979

80-
-
80+
- Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345))
8181

8282

8383
-

pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _dataloader_init_kwargs_resolve_sampler(
184184
batch_sampler = getattr(dataloader, "batch_sampler")
185185
is_predicting = mode == RunningStage.PREDICTING
186186
# checking the batch sampler type is different than PyTorch default.
187-
if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
187+
if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
188188
batch_sampler = type(batch_sampler)(
189189
sampler,
190190
batch_size=batch_sampler.batch_size,

tests/trainer/test_data_loading.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,25 +283,26 @@ class CustomSampler(Sampler):
283283
trainer.prepare_dataloader(dataloader, shuffle=True)
284284

285285

286-
def test_loader_detaching():
287-
"""Checks that the loader has been resetted after the entrypoint."""
286+
class LoaderTestModel(BoringModel):
287+
def training_step(self, batch, batch_idx):
288+
assert len(self.trainer.train_dataloader.loaders) == 10
289+
return super().training_step(batch, batch_idx)
288290

289-
class LoaderTestModel(BoringModel):
290-
def training_step(self, batch, batch_idx):
291-
assert len(self.trainer.train_dataloader.loaders) == 10
292-
return super().training_step(batch, batch_idx)
291+
def validation_step(self, batch, batch_idx):
292+
assert len(self.trainer.val_dataloaders[0]) == 10
293+
return super().validation_step(batch, batch_idx)
293294

294-
def validation_step(self, batch, batch_idx):
295-
assert len(self.trainer.val_dataloaders[0]) == 10
296-
return super().validation_step(batch, batch_idx)
295+
def test_step(self, batch, batch_idx):
296+
assert len(self.trainer.test_dataloaders[0]) == 10
297+
return super().test_step(batch, batch_idx)
297298

298-
def test_step(self, batch, batch_idx):
299-
assert len(self.trainer.test_dataloaders[0]) == 10
300-
return super().test_step(batch, batch_idx)
299+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
300+
assert len(self.trainer.predict_dataloaders[0]) == 10
301+
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
301302

302-
def predict_step(self, batch, batch_idx, dataloader_idx=0):
303-
assert len(self.trainer.predict_dataloaders[0]) == 10
304-
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
303+
304+
def test_loader_detaching():
305+
"""Checks that the loader has been resetted after the entrypoint."""
305306

306307
loader = DataLoader(RandomDataset(32, 10), batch_size=1)
307308

@@ -340,3 +341,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
340341
assert len(model.val_dataloader()) == 64
341342
assert len(model.predict_dataloader()) == 64
342343
assert len(model.test_dataloader()) == 64
344+
345+
346+
def test_pre_made_batches():
347+
"""Check that loader works with pre-made batches."""
348+
loader = DataLoader(RandomDataset(32, 10), batch_size=None)
349+
trainer = Trainer(fast_dev_run=1)
350+
trainer.predict(LoaderTestModel(), loader)

0 commit comments

Comments
 (0)