Commit ed2bcc5

Deprecate `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` (#12151)

1 parent 09d1296

8 files changed: +103 −15 lines
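For downstream code, a minimal before/after sketch of the migration this deprecation implies (illustrative only; `accelerator="cpu", devices=2` is an assumed configuration, not from the diff):

```python
from pytorch_lightning import Trainer

trainer = Trainer(accelerator="cpu", devices=2)

# Before: `Trainer.devices` (deprecated in v1.6, removed in v1.8) now
# emits a deprecation warning and returns the device count.
count = trainer.devices

# After: query the count and the per-node device indexes separately.
count = trainer.num_devices    # e.g. 2
indexes = trainer.device_ids   # e.g. [0, 1]
```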

CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support to explicitly specify the process group backend for parallel strategies ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
 
 
+- Added `device_ids` and `num_devices` property to `Trainer` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))

@@ -518,6 +521,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
 
 
+- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/trainer/trainer.py

Lines changed: 23 additions & 2 deletions

@@ -2010,6 +2010,23 @@ def should_rank_save_checkpoint(self) -> bool:
     def num_nodes(self) -> int:
         return getattr(self.strategy, "num_nodes", 1)
 
+    @property
+    def device_ids(self) -> List[int]:
+        """List of device indexes per node."""
+        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        device_ids = []
+        for idx, device in enumerate(devices):
+            if isinstance(device, torch.device):
+                device_ids.append(device.index or idx)
+            elif isinstance(device, int):
+                device_ids.append(device)
+        return device_ids
+
+    @property
+    def num_devices(self) -> int:
+        """Number of devices the trainer uses per node."""
+        return len(self.device_ids)
+
     @property
     def num_processes(self) -> int:
         return self._accelerator_connector.num_processes

@@ -2031,8 +2048,12 @@ def num_gpus(self) -> int:
         return self._accelerator_connector.num_gpus
 
     @property
-    def devices(self) -> Optional[Union[List[int], str, int]]:
-        return self._accelerator_connector.devices
+    def devices(self) -> int:
+        rank_zero_deprecation(
+            "`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
+            " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
+        )
+        return self.num_devices
 
     @property
     def data_parallel_device_ids(self) -> Optional[List[int]]:
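To illustrate the indexing rule in the new `device_ids` property, here is a standalone sketch of the same loop using plain `torch` (an illustration outside the `Trainer`, not the library code itself):

```python
import torch

def device_ids_from(devices):
    # Same rule as the property above: a torch.device contributes its
    # index when set; `device.index or idx` falls back to the position
    # in the list when the index is None (e.g. CPU devices). Note that
    # a falsy index 0 also takes the fallback, which coincides with the
    # usual layout where cuda:0 sits at position 0. Plain ints (e.g.
    # TPU/IPU indexes) pass through unchanged.
    ids = []
    for idx, device in enumerate(devices):
        if isinstance(device, torch.device):
            ids.append(device.index or idx)
        elif isinstance(device, int):
            ids.append(device)
    return ids

print(device_ids_from([torch.device("cuda", 0), torch.device("cuda", 2)]))  # [0, 2]
print(device_ids_from([torch.device("cpu")] * 3))                           # [0, 1, 2]
```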

tests/accelerators/test_accelerator_connector.py

Lines changed: 5 additions & 5 deletions

@@ -579,22 +579,22 @@ def test_validate_accelerator_and_devices():
 def test_set_devices_if_none_cpu():
 
     trainer = Trainer(accelerator="cpu", num_processes=3)
-    assert trainer.devices == 3
+    assert trainer.num_devices == 3
 
 
 @RunIf(min_gpus=2)
 def test_set_devices_if_none_gpu():
 
     trainer = Trainer(accelerator="gpu", gpus=2)
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
 
 
 def test_devices_with_cpu_only_supports_integer():
 
     with pytest.warns(UserWarning, match="The flag `devices` must be an int"):
         trainer = Trainer(accelerator="cpu", devices="1,3")
     assert isinstance(trainer.accelerator, CPUAccelerator)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
 
 
 @pytest.mark.parametrize("training_type", ["ddp2", "dp"])

@@ -941,15 +941,15 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
 @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False)
 def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
     assert trainer.num_processes == 1
 
 
 @mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=2)
 def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     assert trainer.gpus == 2
tests/accelerators/test_ipu.py

Lines changed: 3 additions & 3 deletions

@@ -400,7 +400,7 @@ def test_manual_poptorch_opts(tmpdir):
     dataloader = trainer.train_dataloader.loaders
     assert isinstance(dataloader, poptorch.DataLoader)
     assert dataloader.options == training_opts
-    assert trainer.devices > 1  # testing this only makes sense in a distributed setting
+    assert trainer.num_devices > 1  # testing this only makes sense in a distributed setting
     assert not isinstance(dataloader.sampler, DistributedSampler)
 

@@ -588,7 +588,7 @@ def test_accelerator_ipu_with_ipus_priority():
 def test_set_devices_if_none_ipu():
 
     trainer = Trainer(accelerator="ipu", ipus=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
 
 
 @RunIf(ipu=True)

@@ -631,5 +631,5 @@ def test_poptorch_models_at_different_stages(tmpdir):
 @RunIf(ipu=True)
 def test_devices_auto_choice_ipu():
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 4
+    assert trainer.num_devices == 4
     assert trainer.ipus == 4

tests/accelerators/test_tpu.py

Lines changed: 22 additions & 3 deletions

@@ -94,14 +94,15 @@ def test_accelerator_cpu_with_tpu_cores_flag():
 
 
 @RunIf(tpu=True)
+@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()
 
     trainer = Trainer(accelerator=accelerator, devices=devices)
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
     assert trainer.tpu_cores == 8
 

@@ -117,10 +118,10 @@ def test_accelerator_tpu_with_tpu_cores_priority():
 
 
 @RunIf(tpu=True)
+@pl_multi_process_test
 def test_set_devices_if_none_tpu():
-
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
 
 
 @RunIf(tpu=True)

@@ -310,3 +311,21 @@ def test_mp_device_dataloader_attribute(_):
 def test_warning_if_tpus_not_used():
     with pytest.warns(UserWarning, match="TPU available but not used. Set `accelerator` and `devices`"):
         Trainer()
+
+
+@pytest.mark.skip(reason="TODO(@kaushikb11): Optimize TPU tests to avoid timeouts")
+@RunIf(tpu=True)
+@pytest.mark.parametrize(
+    ["devices", "expected_device_ids"],
+    [
+        (1, [0]),
+        (8, list(range(8))),
+        ("8", list(range(8))),
+        ([2], [2]),
+        ("2,", [2]),
+    ],
+)
+def test_trainer_config_device_ids(devices, expected_device_ids):
+    trainer = Trainer(accelerator="tpu", devices=devices)
+    assert trainer.device_ids == expected_device_ids
+    assert trainer.num_devices == len(expected_device_ids)
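The parametrization above pins down how the `devices` flag normalizes to a list of indexes: an int or numeric string is a count expanded to `range(n)`, while a list or a comma-style string such as `"2,"` selects explicit indexes. A hypothetical re-implementation of just that mapping, for illustration (this is not Lightning's actual parser):

```python
def normalize_devices(devices):
    # Hypothetical sketch mirroring the expected values in the test above.
    if isinstance(devices, int):
        return list(range(devices))               # 8 -> [0, ..., 7]
    if isinstance(devices, (list, tuple)):
        return [int(d) for d in devices]          # [2] -> [2]
    if isinstance(devices, str):
        if "," in devices:
            # "2," -> [2]; empty segments from a trailing comma are dropped.
            return [int(d) for d in devices.split(",") if d.strip()]
        return list(range(int(devices)))          # "8" -> [0, ..., 7]
    raise ValueError(f"Unsupported devices value: {devices!r}")

assert normalize_devices(1) == [0]
assert normalize_devices("8") == list(range(8))
assert normalize_devices("2,") == [2]
```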

tests/deprecated_api/test_remove_1-8.py

Lines changed: 9 additions & 0 deletions

@@ -878,3 +878,12 @@ def all_gather(self, tensor):
         match="ParallelStrategy.torch_distributed_backend was deprecated" " in v1.6 and will be removed in v1.8."
     ):
         strategy.torch_distributed_backend
+
+
+def test_trainer_config_device_ids():
+    trainer = Trainer(devices=2)
+    with pytest.deprecated_call(
+        match="`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
+        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
+    ):
+        trainer.devices == 2

tests/trainer/flags/test_env_vars.py

Lines changed: 2 additions & 2 deletions

@@ -51,6 +51,6 @@ def test_passing_env_variables_defaults():
 def test_passing_env_variables_devices(cuda_available_mock, device_count_mock):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1

tests/trainer/test_trainer.py

Lines changed: 33 additions & 0 deletions

@@ -30,6 +30,7 @@
 from torch.optim import SGD
 from torch.utils.data import DataLoader, IterableDataset
 
+import pytorch_lightning
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator

@@ -2117,3 +2118,35 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
         else getattr(trainer, f"{dl_prefix}_dataloaders")
     )
     assert dl is None
+
+
+@pytest.mark.parametrize(
+    ["trainer_kwargs", "expected_device_ids"],
+    [
+        ({}, [0]),
+        ({"devices": 1}, [0]),
+        ({"devices": 1}, [0]),
+        ({"devices": "1"}, [0]),
+        ({"devices": 2}, [0, 1]),
+        ({"accelerator": "gpu", "devices": 1}, [0]),
+        ({"accelerator": "gpu", "devices": 2}, [0, 1]),
+        ({"accelerator": "gpu", "devices": "2"}, [0, 1]),
+        ({"accelerator": "gpu", "devices": [2]}, [2]),
+        ({"accelerator": "gpu", "devices": "2,"}, [2]),
+        ({"accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
+        ({"accelerator": "gpu", "devices": "0, 2"}, [0, 2]),
+        ({"accelerator": "ipu", "devices": 1}, [0]),
+        ({"accelerator": "ipu", "devices": 2}, [0, 1]),
+    ],
+)
+def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
+    if trainer_kwargs.get("accelerator") == "gpu":
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+        monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
+    elif trainer_kwargs.get("accelerator") == "ipu":
+        monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
+        monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", lambda: True)
+
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.device_ids == expected_device_ids
+    assert trainer.num_devices == len(expected_device_ids)
