Skip to content

Commit a5ce5a7

Browse files
committed
fix tests
1 parent 4fc502a commit a5ce5a7

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
SimpleProfiler,
5959
XLAProfiler,
6060
)
61-
from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
61+
from pytorch_lightning.strategies import ParallelStrategy, Strategy
6262
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
6363
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
6464
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
@@ -2028,11 +2028,14 @@ def num_nodes(self) -> int:
20282028

20292029
@property
20302030
def device_ids(self) -> List[int]:
2031-
if isinstance(self.strategy, ParallelStrategy):
2032-
return [torch._utils._get_device_index(device, allow_cpu=True) for device in self.strategy.parallel_devices]
2033-
elif isinstance(self.strategy, SingleDeviceStrategy):
2034-
return [torch._utils._get_device_index(self.strategy.root_device, allow_cpu=True)]
2035-
return []
2031+
devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
2032+
device_ids = []
2033+
for idx, device in enumerate(devices):
2034+
if isinstance(device, torch.device):
2035+
device_ids.append(device.index or idx)
2036+
elif isinstance(device, int):
2037+
device_ids.append(device)
2038+
return device_ids
20362039

20372040
@property
20382041
def num_devices(self) -> int:

tests/trainer/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,7 +2151,7 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
21512151
@pytest.mark.parametrize(
21522152
["trainer_kwargs", "expected_device_ids"],
21532153
[
2154-
({"strategy": None}, []),
2154+
({"strategy": None}, [0]),
21552155
({"num_processes": 1}, [0]),
21562156
({"gpus": 1}, [0]),
21572157
({"devices": 1}, [0]),
@@ -2166,4 +2166,4 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_
21662166
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
21672167
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
21682168
trainer = Trainer(**trainer_kwargs)
2169-
trainer.num_devices = expected_device_ids
2169+
assert trainer.device_ids == expected_device_ids

0 commit comments

Comments
 (0)