
Commit b469397

avoid calling xm.xla_device() on TPUSpawnStrategy

1 parent 7e82dfa

2 files changed: +4 -7 lines

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@
     SimpleProfiler,
     XLAProfiler,
 )
-from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
@@ -2056,7 +2056,7 @@ def num_nodes(self) -> int:
     @property
     def device_ids(self) -> List[int]:
         """List of device indexes per node."""
-        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        devices = [self.strategy.root_device] if isinstance(self.strategy, SingleDeviceStrategy) else self.strategy.parallel_devices
         device_ids = []
         for idx, device in enumerate(devices):
             if isinstance(device, torch.device):
@@ -2068,7 +2068,7 @@ def device_ids(self) -> List[int]:
     @property
     def num_devices(self) -> int:
         """Number of devices the trainer uses per node."""
-        return len(self.device_ids)
+        return 1 if isinstance(self.strategy, SingleDeviceStrategy) else len(self.device_ids)

     @property
     def num_processes(self) -> int:
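
The replaced `getattr` form looks harmless, but Python evaluates the default argument eagerly: the list `[self.strategy.root_device]` is constructed before `getattr` even runs, including when the strategy does define `parallel_devices`. Since `TPUSpawnStrategy.root_device` resolves through `xm.xla_device()`, the old code requested an XLA device in the parent process on every access. A minimal sketch of the pitfall, with illustrative names that are not from the commit:

import torch

class FakeSpawnStrategy:
    # The attribute exists, so getattr's default should never be needed...
    parallel_devices = [torch.device("cpu"), torch.device("cpu")]

    @property
    def root_device(self):
        print("root_device accessed!")  # stands in for xm.xla_device()
        return torch.device("cpu")

strategy = FakeSpawnStrategy()
# ...yet the default list is still built first, so root_device (and thus
# the stand-in for xm.xla_device()) fires anyway:
devices = getattr(strategy, "parallel_devices", [strategy.root_device])

The explicit `isinstance(self.strategy, SingleDeviceStrategy)` branch only touches `root_device` when that is genuinely the sole device.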

tests/accelerators/test_tpu.py

Lines changed: 1 addition & 4 deletions
@@ -94,7 +94,6 @@ def test_accelerator_cpu_with_tpu_cores_flag():


 @RunIf(tpu=True)
-@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()
@@ -104,13 +103,12 @@ def test_accelerator_tpu(accelerator, devices):
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
     assert trainer.num_devices == 8
     with pytest.deprecated_call(
-        match= "`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+        match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
         "Please use `Trainer.devices` instead."
     ):
         trainer.tpu_cores == 8


-
 @RunIf(tpu=True)
 def test_accelerator_tpu_with_tpu_cores_priority():
     """Test for checking `tpu_cores` flag takes priority over `devices`."""
@@ -124,7 +122,6 @@ def test_accelerator_tpu_with_tpu_cores_priority():


 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_set_devices_if_none_tpu():
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
     assert trainer.num_devices == 8
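
Condensed, the behavior these tests now exercise on a TPU host looks like the following sketch; the `Trainer` arguments and the deprecation message are taken from the test, the rest is illustrative:

import pytest
from pytorch_lightning import Trainer

# Eight TPU devices are resolved without the parent process having to
# call xm.xla_device(), per the trainer.py change above.
trainer = Trainer(accelerator="tpu", devices=8)
assert trainer.num_devices == 8

# The legacy attribute still answers, but only under a deprecation warning.
with pytest.deprecated_call(
    match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
    "Please use `Trainer.devices` instead."
):
    _ = trainer.tpu_cores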
