
Commit de52d05

add device_ids and num_devices to deprecate Trainer.devices
1 parent: 0fe3379

File tree: 7 files changed, +57 -14 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))


+- Added `device_ids` and `num_devices` property to `Trainer` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Changed

 - Make `benchmark` flag optional and set its value based on the deterministic flag ([#11944](https://github.com/PyTorchLightning/pytorch-lightning/pull/11944))

@@ -448,6 +451,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))


+- Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
+
+
 ### Removed

 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/trainer/trainer.py

Lines changed: 18 additions & 2 deletions

@@ -58,7 +58,7 @@
     SimpleProfiler,
     XLAProfiler,
 )
-from pytorch_lightning.strategies import ParallelStrategy, Strategy
+from pytorch_lightning.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations

@@ -2026,6 +2026,18 @@ def should_rank_save_checkpoint(self) -> bool:
     def num_nodes(self) -> int:
         return getattr(self.strategy, "num_nodes", 1)

+    @property
+    def device_ids(self) -> List[int]:
+        if isinstance(self.strategy, ParallelStrategy):
+            return [torch._utils._get_device_index(device, allow_cpu=True) for device in self.strategy.parallel_devices]
+        elif isinstance(self.strategy, SingleDeviceStrategy):
+            return [torch._utils._get_device_index(self.strategy.root_device, allow_cpu=True)]
+        return []
+
+    @property
+    def num_devices(self) -> int:
+        return len(self.device_ids)
+
     @property
     def num_processes(self) -> int:
         return self._accelerator_connector.num_processes

@@ -2048,7 +2060,11 @@ def num_gpus(self) -> int:

     @property
     def devices(self) -> Optional[Union[List[int], str, int]]:
-        return self._accelerator_connector.devices
+        rank_zero_deprecation(
+            "`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
+            " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
+        )
+        return self.num_devices

     @property
     def data_parallel_device_ids(self) -> Optional[List[int]]:
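For orientation, here is a minimal usage sketch of the API this commit introduces. It is not part of the diff; the exact indices depend on the configured accelerator and strategy (CPU devices, for instance, resolve through `torch._utils._get_device_index` with `allow_cpu=True`):

from pytorch_lightning import Trainer

# Two CPU processes -> a ParallelStrategy, so `device_ids` is built from
# `strategy.parallel_devices`, and `num_devices` is simply its length.
trainer = Trainer(accelerator="cpu", devices=2)
print(trainer.num_devices)  # 2
print(trainer.device_ids)   # one index per configured device

# The old property still resolves, but it now emits a deprecation warning
# and returns the device count rather than the raw `devices` flag:
n = trainer.devices  # warns: deprecated in v1.6, will be removed in v1.8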

tests/accelerators/test_accelerator_connector.py

Lines changed: 5 additions & 5 deletions

@@ -571,22 +571,22 @@ def test_validate_accelerator_and_devices():
 def test_set_devices_if_none_cpu():

     trainer = Trainer(accelerator="cpu", num_processes=3)
-    assert trainer.devices == 3
+    assert trainer.num_devices == 3


 @RunIf(min_gpus=2)
 def test_set_devices_if_none_gpu():

     trainer = Trainer(accelerator="gpu", gpus=2)
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2


 def test_devices_with_cpu_only_supports_integer():

     with pytest.warns(UserWarning, match="The flag `devices` must be an int"):
         trainer = Trainer(accelerator="cpu", devices="1,3")
     assert isinstance(trainer.accelerator, CPUAccelerator)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1


 @pytest.mark.parametrize("training_type", ["ddp2", "dp"])

@@ -931,15 +931,15 @@ def test_unsupported_ipu_choice(monkeypatch):
 @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False)
 def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1
     assert trainer.num_processes == 1


 @mock.patch("torch.cuda.is_available", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=2)
 def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     assert trainer.gpus == 2

tests/accelerators/test_ipu.py

Lines changed: 3 additions & 3 deletions

@@ -398,7 +398,7 @@ def test_manual_poptorch_opts(tmpdir):
     dataloader = trainer.train_dataloader.loaders
     assert isinstance(dataloader, poptorch.DataLoader)
     assert dataloader.options == training_opts
-    assert trainer.devices > 1  # testing this only makes sense in a distributed setting
+    assert trainer.num_devices > 1  # testing this only makes sense in a distributed setting
     assert not isinstance(dataloader.sampler, DistributedSampler)


@@ -586,7 +586,7 @@ def test_accelerator_ipu_with_ipus_priority():
 def test_set_devices_if_none_ipu():

     trainer = Trainer(accelerator="ipu", ipus=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8


 @RunIf(ipu=True)

@@ -629,5 +629,5 @@ def test_poptorch_models_at_different_stages(tmpdir):
 @RunIf(ipu=True)
 def test_devices_auto_choice_ipu():
     trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 4
+    assert trainer.num_devices == 4
     assert trainer.ipus == 4

tests/accelerators/test_tpu.py

Lines changed: 2 additions & 2 deletions

@@ -101,7 +101,7 @@ def test_accelerator_tpu(accelerator, devices):
     trainer = Trainer(accelerator=accelerator, devices=devices)
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8
     assert trainer.tpu_cores == 8


@@ -120,7 +120,7 @@ def test_accelerator_tpu_with_tpu_cores_priority():
 def test_set_devices_if_none_tpu():

     trainer = Trainer(accelerator="tpu", tpu_cores=8)
-    assert trainer.devices == 8
+    assert trainer.num_devices == 8


 @RunIf(tpu=True)

tests/trainer/flags/test_env_vars.py

Lines changed: 2 additions & 2 deletions

@@ -51,6 +51,6 @@ def test_passing_env_variables_defaults():
 def test_passing_env_variables_devices(cuda_available_mock, device_count_mock):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
-    assert trainer.devices == 2
+    assert trainer.num_devices == 2
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert trainer.devices == 1
+    assert trainer.num_devices == 1

tests/trainer/test_trainer.py

Lines changed: 21 additions & 0 deletions

@@ -2146,3 +2146,24 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
         else getattr(trainer, f"{dl_prefix}_dataloaders")
     )
     assert dl is None
+
+
+@pytest.mark.parametrize(
+    ["trainer_kwargs", "expected_device_ids"],
+    [
+        ({"strategy": None}, []),
+        ({"num_processes": 1}, [0]),
+        ({"gpus": 1}, [0]),
+        ({"devices": 1}, [0]),
+        ({"strategy": "ddp", "devices": 1}, [0]),
+        ({"strategy": "ddp", "gpus": 2}, [0, 1]),
+        ({"strategy": "ddp", "num_processes": 2}, [0, 1]),
+        ({"strategy": "ddp", "gpus": [0, 2]}, [0, 2]),
+    ],
+)
+def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
+    if trainer_kwargs.get("gpus") is not None:
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+        monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.device_ids == expected_device_ids
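One design note on the new test: the GPU cases fake CUDA availability with pytest's `monkeypatch` fixture so they can run on CPU-only CI; with four pretend GPUs visible, `{"strategy": "ddp", "gpus": [0, 2]}` can legitimately resolve to the non-contiguous indices `[0, 2]`. A standalone sketch of the same trick (hypothetical test name, for illustration only):

import torch

def test_with_four_fake_gpus(monkeypatch):
    # Pretend CUDA is available with four visible devices; pytest's
    # monkeypatch fixture restores the real functions after the test.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
    assert torch.cuda.is_available() and torch.cuda.device_count() == 4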
