|
24 | 24 | from torch.utils.data import DataLoader, DistributedSampler, Sampler
|
25 | 25 |
|
26 | 26 | from pytorch_lightning.lite import LightningLite
|
27 |
| -from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer |
| 27 | +from pytorch_lightning.lite.wrappers import ( |
| 28 | + _LiteDataLoader, |
| 29 | + _LiteModule, |
| 30 | + _LiteOptimizer, |
| 31 | + _replace_dataloader_init_method, |
| 32 | +) |
28 | 33 | from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
|
29 | 34 | from pytorch_lightning.utilities import DistributedType
|
30 | 35 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
@@ -192,57 +197,6 @@ def run(self):
|
192 | 197 | LiteWithCustomDataLoader().run()
|
193 | 198 |
|
194 | 199 |
|
195 |
| -def test_setup_custom_dataloaders(): |
196 |
| - """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader.""" |
197 |
| - lite = EmptyLite() |
198 |
| - |
199 |
| - class CustomDataLoader(DataLoader): |
200 |
| - def __init__(self, value: int = 2, *args, **kwargs): |
201 |
| - self.value = value |
202 |
| - super().__init__(range(value), *args, **kwargs) |
203 |
| - |
204 |
| - dataloader = CustomDataLoader(2, batch_size=2) |
205 |
| - |
206 |
| - # single dataloader |
207 |
| - lite_dataloader = lite.setup_dataloaders(dataloader) |
208 |
| - assert lite_dataloader._dataloader |
209 |
| - assert lite_dataloader.value == 2 |
210 |
| - batch0 = next(iter(lite_dataloader)) |
211 |
| - assert torch.equal(batch0, torch.tensor([0, 1])) |
212 |
| - |
213 |
| - class CustomDataLoader2(DataLoader): |
214 |
| - def __init__(self, range, *args, **kwargs): |
215 |
| - self.range = range |
216 |
| - super().__init__(range, *args, **kwargs) |
217 |
| - |
218 |
| - dataloader = CustomDataLoader2(range(2), batch_size=2) |
219 |
| - |
220 |
| - # single dataloader |
221 |
| - lite_dataloader = lite.setup_dataloaders(dataloader) |
222 |
| - assert lite_dataloader._dataloader |
223 |
| - batch0 = next(iter(lite_dataloader)) |
224 |
| - assert torch.equal(batch0, torch.tensor([0, 1])) |
225 |
| - |
226 |
| - class CustomDataLoader(DataLoader): |
227 |
| - def __init__(self, value: int, *args, **kwargs): |
228 |
| - super().__init__(range(value), *args, **kwargs) |
229 |
| - |
230 |
| - class LiteWithCustomDataLoader(LightningLite): |
231 |
| - def run(self): |
232 |
| - # This doesn't fail as the context manager would save all the arguments provided |
233 |
| - # to the dataloaders. |
234 |
| - dataloader = CustomDataLoader(2, batch_size=2) |
235 |
| - self.setup_dataloaders(dataloader) |
236 |
| - |
237 |
| - LiteWithCustomDataLoader().run() |
238 |
| - |
239 |
| - with pytest.raises( |
240 |
| - MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance" |
241 |
| - ): |
242 |
| - dataloader = CustomDataLoader(2, batch_size=2) |
243 |
| - lite_dataloader = lite.setup_dataloaders(dataloader) |
244 |
| - |
245 |
| - |
246 | 200 | def test_setup_dataloaders_twice_fails():
|
247 | 201 | """Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
|
248 | 202 | lite = EmptyLite()
|
@@ -490,3 +444,25 @@ def run(self):
|
490 | 444 | assert self.is_global_zero == (self.local_rank == 0)
|
491 | 445 |
|
492 | 446 | Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
|
| 447 | + |
| 448 | + |
| 449 | +def test_replace_dataloader_init_method(): |
| 450 | + """Test that the context manager enables to save the parameters passed to the DataLoader __init__ method.""" |
| 451 | + |
| 452 | + class CustomDataLoader(DataLoader): |
| 453 | + def __init__(self, extra_argument: int, *args, **kwargs): |
| 454 | + super().__init__(*args, **kwargs) |
| 455 | + |
| 456 | + dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) |
| 457 | + lite = EmptyLite() |
| 458 | + with pytest.raises(MisconfigurationException, match="extra_argument"): |
| 459 | + dataloader = lite.setup_dataloaders(dataloader) |
| 460 | + |
| 461 | + with _replace_dataloader_init_method(): |
| 462 | + dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) |
| 463 | + assert dataloader.extra_argument == 1 |
| 464 | + dataloader = lite.setup_dataloaders(dataloader) |
| 465 | + |
| 466 | + dataloader = CustomDataLoader(1, range(1)) |
| 467 | + assert dataloader.extra_argument == 1 |
| 468 | + dataloader = lite.setup_dataloaders(dataloader) |
0 commit comments