Skip to content

Commit c33df26

Browse files
authored
Set dataset attribute to MpDeviceLoader used in TPU Spawn (Lightning-AI#10151)
1 parent 5ade197 commit c33df26

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,10 @@ def is_distributed(self) -> bool:
138138

139139
def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
140140
TPUSpawnPlugin._validate_dataloader(dataloader)
141-
return MpDeviceLoader(dataloader, self.root_device)
141+
dataloader = MpDeviceLoader(dataloader, self.root_device)
142+
# Mimic interface to torch.utils.data.DataLoader
143+
dataloader.dataset = dataloader._loader.dataset
144+
return dataloader
142145

143146
def configure_ddp(self) -> None:
144147
pass

tests/accelerators/test_tpu.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
import pytest
1919
import torch
2020
from torch import nn
21+
from torch.utils.data import DataLoader
2122

2223
from pytorch_lightning import Trainer
2324
from pytorch_lightning.accelerators.cpu import CPUAccelerator
2425
from pytorch_lightning.accelerators.tpu import TPUAccelerator
2526
from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
2627
from pytorch_lightning.utilities import find_shared_parameters
2728
from pytorch_lightning.utilities.exceptions import MisconfigurationException
28-
from tests.helpers.boring_model import BoringModel
29+
from tests.helpers.boring_model import BoringModel, RandomDataset
2930
from tests.helpers.runif import RunIf
3031
from tests.helpers.utils import pl_multi_process_test
3132

@@ -300,3 +301,12 @@ def test_tpu_invalid_raises():
300301
def test_xla_checkpoint_plugin_being_default():
301302
trainer = Trainer(tpu_cores=8)
302303
assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO)
304+
305+
306+
@RunIf(tpu=True)
307+
@patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
308+
def test_mp_device_dataloader_attribute(_):
309+
dataset = RandomDataset(32, 64)
310+
dataloader = TPUSpawnPlugin().process_dataloader(DataLoader(dataset))
311+
312+
assert dataloader.dataset == dataset

0 commit comments

Comments
 (0)