File tree Expand file tree Collapse file tree 2 files changed +15
-2
lines changed
pytorch_lightning/plugins/training_type Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -138,7 +138,10 @@ def is_distributed(self) -> bool:
138
138
139
139
def process_dataloader (self , dataloader : DataLoader ) -> MpDeviceLoader :
140
140
TPUSpawnPlugin ._validate_dataloader (dataloader )
141
- return MpDeviceLoader (dataloader , self .root_device )
141
+ dataloader = MpDeviceLoader (dataloader , self .root_device )
142
+ # Mimic interface to torch.utils.data.DataLoader
143
+ dataloader .dataset = dataloader ._loader .dataset
144
+ return dataloader
142
145
143
146
def configure_ddp (self ) -> None :
144
147
pass
Original file line number Diff line number Diff line change 18
18
import pytest
19
19
import torch
20
20
from torch import nn
21
+ from torch .utils .data import DataLoader
21
22
22
23
from pytorch_lightning import Trainer
23
24
from pytorch_lightning .accelerators .cpu import CPUAccelerator
24
25
from pytorch_lightning .accelerators .tpu import TPUAccelerator
25
26
from pytorch_lightning .plugins import TPUPrecisionPlugin , TPUSpawnPlugin , XLACheckpointIO
26
27
from pytorch_lightning .utilities import find_shared_parameters
27
28
from pytorch_lightning .utilities .exceptions import MisconfigurationException
28
- from tests .helpers .boring_model import BoringModel
29
+ from tests .helpers .boring_model import BoringModel , RandomDataset
29
30
from tests .helpers .runif import RunIf
30
31
from tests .helpers .utils import pl_multi_process_test
31
32
@@ -300,3 +301,12 @@ def test_tpu_invalid_raises():
300
301
def test_xla_checkpoint_plugin_being_default ():
301
302
trainer = Trainer (tpu_cores = 8 )
302
303
assert isinstance (trainer .training_type_plugin .checkpoint_io , XLACheckpointIO )
304
+
305
+
306
+ @RunIf (tpu = True )
307
+ @patch ("pytorch_lightning.plugins.training_type.tpu_spawn.xm" )
308
+ def test_mp_device_dataloader_attribute (_ ):
309
+ dataset = RandomDataset (32 , 64 )
310
+ dataloader = TPUSpawnPlugin ().process_dataloader (DataLoader (dataset ))
311
+
312
+ assert dataloader .dataset == dataset
You can’t perform that action at this time.
0 commit comments