Skip to content

Commit ef3360e

Browse files
carmocca and awaelchli
authored and committed
Revert part of #10279 (#10376)
1 parent b9a8f74 commit ef3360e

File tree

2 files changed

+35
-62
lines changed

2 files changed

+35
-62
lines changed

pytorch_lightning/lite/lite.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,15 @@ def _setup_dataloader(
238238
)
239239
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
240240

241-
dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
242-
try:
243-
dataloader = type(dataloader)(**dataloader_kwargs)
244-
except TypeError:
245-
dataloader_kwargs.pop("dataset")
246-
dataloader = type(dataloader)(**dataloader_kwargs)
241+
# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
242+
dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
243+
247244
# add worker_init_fn for correct seeding in worker processes
248245
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
249-
return _LiteDataLoader(
250-
dataloader=self._strategy.process_dataloader(dataloader),
251-
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
252-
)
246+
247+
dataloader = self._strategy.process_dataloader(dataloader)
248+
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
249+
return _LiteDataLoader(dataloader=dataloader, device=device)
253250

254251
def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
255252
"""Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.

tests/lite/test_lite.py

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from torch.utils.data import DataLoader, DistributedSampler, Sampler
2525

2626
from pytorch_lightning.lite import LightningLite
27-
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
27+
from pytorch_lightning.lite.wrappers import (
28+
_LiteDataLoader,
29+
_LiteModule,
30+
_LiteOptimizer,
31+
_replace_dataloader_init_method,
32+
)
2833
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
2934
from pytorch_lightning.utilities import DistributedType
3035
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,57 +197,6 @@ def run(self):
192197
LiteWithCustomDataLoader().run()
193198

194199

195-
def test_setup_custom_dataloaders():
196-
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
197-
lite = EmptyLite()
198-
199-
class CustomDataLoader(DataLoader):
200-
def __init__(self, value: int = 2, *args, **kwargs):
201-
self.value = value
202-
super().__init__(range(value), *args, **kwargs)
203-
204-
dataloader = CustomDataLoader(2, batch_size=2)
205-
206-
# single dataloader
207-
lite_dataloader = lite.setup_dataloaders(dataloader)
208-
assert lite_dataloader._dataloader
209-
assert lite_dataloader.value == 2
210-
batch0 = next(iter(lite_dataloader))
211-
assert torch.equal(batch0, torch.tensor([0, 1]))
212-
213-
class CustomDataLoader2(DataLoader):
214-
def __init__(self, range, *args, **kwargs):
215-
self.range = range
216-
super().__init__(range, *args, **kwargs)
217-
218-
dataloader = CustomDataLoader2(range(2), batch_size=2)
219-
220-
# single dataloader
221-
lite_dataloader = lite.setup_dataloaders(dataloader)
222-
assert lite_dataloader._dataloader
223-
batch0 = next(iter(lite_dataloader))
224-
assert torch.equal(batch0, torch.tensor([0, 1]))
225-
226-
class CustomDataLoader(DataLoader):
227-
def __init__(self, value: int, *args, **kwargs):
228-
super().__init__(range(value), *args, **kwargs)
229-
230-
class LiteWithCustomDataLoader(LightningLite):
231-
def run(self):
232-
# This doesn't fail as the context manager would save all the arguments provided
233-
# to the dataloaders.
234-
dataloader = CustomDataLoader(2, batch_size=2)
235-
self.setup_dataloaders(dataloader)
236-
237-
LiteWithCustomDataLoader().run()
238-
239-
with pytest.raises(
240-
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
241-
):
242-
dataloader = CustomDataLoader(2, batch_size=2)
243-
lite_dataloader = lite.setup_dataloaders(dataloader)
244-
245-
246200
def test_setup_dataloaders_twice_fails():
247201
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
248202
lite = EmptyLite()
@@ -490,3 +444,25 @@ def run(self):
490444
assert self.is_global_zero == (self.local_rank == 0)
491445

492446
Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
447+
448+
449+
def test_replace_dataloader_init_method():
450+
"""Test that the context manager enables to save the parameters passed to the DataLoader __init__ method."""
451+
452+
class CustomDataLoader(DataLoader):
453+
def __init__(self, extra_argument: int, *args, **kwargs):
454+
super().__init__(*args, **kwargs)
455+
456+
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
457+
lite = EmptyLite()
458+
with pytest.raises(MisconfigurationException, match="extra_argument"):
459+
dataloader = lite.setup_dataloaders(dataloader)
460+
461+
with _replace_dataloader_init_method():
462+
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
463+
assert dataloader.extra_argument == 1
464+
dataloader = lite.setup_dataloaders(dataloader)
465+
466+
dataloader = CustomDataLoader(1, range(1))
467+
assert dataloader.extra_argument == 1
468+
dataloader = lite.setup_dataloaders(dataloader)

0 commit comments

Comments (0)