Skip to content

Commit 4fba371

Browse files
committed
address comments
1 parent 05a8469 commit 4fba371

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
SimpleProfiler,
5959
XLAProfiler,
6060
)
61-
from pytorch_lightning.strategies import ParallelStrategy, Strategy
61+
from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
6262
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
6363
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
6464
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
@@ -2028,6 +2028,7 @@ def num_nodes(self) -> int:
20282028

20292029
@property
20302030
def device_ids(self) -> List[int]:
2031+
"""List of device indexes per node."""
20312032
devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
20322033
device_ids = []
20332034
for idx, device in enumerate(devices):
@@ -2039,7 +2040,12 @@ def device_ids(self) -> List[int]:
20392040

20402041
@property
20412042
def num_devices(self) -> int:
2042-
return len(self.device_ids)
2043+
"""Number of devices per node."""
2044+
if isinstance(self.strategy, SingleDeviceStrategy):
2045+
return 1
2046+
elif isinstance(self.strategy, ParallelStrategy):
2047+
return len(self.strategy.parallel_devices)
2048+
return 0
20432049

20442050
@property
20452051
def num_processes(self) -> int:

tests/deprecated_api/test_remove_1-8.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,3 +688,29 @@ def on_load_checkpoint(self, checkpoint):
688688
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
689689
):
690690
trainer.fit(model)
691+
692+
693+
@pytest.mark.parametrize(
694+
["trainer_kwargs", "expected_devices"],
695+
[
696+
({}, 1),
697+
({"devices": 1}, 1),
698+
({"accelerator": "gpu", "devices": 1}, 1),
699+
({"strategy": "ddp", "devices": 1}, 1),
700+
({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, 1),
701+
({"strategy": "ddp", "devices": 2}, 2),
702+
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, 2),
703+
({"strategy": "ddp", "accelerator": "gpu", "devices": [2]}, 1),
704+
({"strategy": "ddp", "accelerator": "gpu", "devices": [0, 2]}, 2),
705+
],
706+
)
707+
def test_v1_8_0_trainer_devices(monkeypatch, trainer_kwargs, expected_devices):
708+
if trainer_kwargs.get("accelerator") == "gpu":
709+
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
710+
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
711+
trainer = Trainer(**trainer_kwargs)
712+
with pytest.deprecated_call(
713+
match="`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
714+
" Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
715+
):
716+
trainer.devices == expected_devices

tests/trainer/test_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,13 +2151,14 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
21512151
@pytest.mark.parametrize(
21522152
["trainer_kwargs", "expected_device_ids"],
21532153
[
2154-
({"strategy": None}, [0]),
2155-
({"num_processes": 1}, [0]),
2154+
({}, [0]),
21562155
({"devices": 1}, [0]),
21572156
({"accelerator": "gpu", "devices": 1}, [0]),
21582157
({"strategy": "ddp", "devices": 1}, [0]),
2159-
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, [0, 1]),
2158+
({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, [0]),
21602159
({"strategy": "ddp", "devices": 2}, [0, 1]),
2160+
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, [0, 1]),
2161+
({"strategy": "ddp", "accelerator": "gpu", "devices": [2]}, [2]),
21612162
({"strategy": "ddp", "accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
21622163
],
21632164
)
@@ -2167,3 +2168,4 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_
21672168
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
21682169
trainer = Trainer(**trainer_kwargs)
21692170
assert trainer.device_ids == expected_device_ids
2171+
assert len(trainer.device_ids) == trainer.num_devices

0 commit comments

Comments
 (0)