Skip to content

Commit 01d817c

Browse files
Deprecate Trainer.gpus (#12436)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent e072336 commit 01d817c

File tree

6 files changed

+39
-5
lines changed

6 files changed

+39
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
600600
- Deprecated passing only the callback state to `Callback.on_load_checkpoint(callback_state)` in favor of passing the callback state to `Callback.load_state_dict` and in 1.8, passing the entire checkpoint dictionary to `Callback.on_load_checkpoint(checkpoint)` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
601601

602602

603+
- Deprecated `Trainer.gpus` in favor of `Trainer.device_ids` or `Trainer.num_devices` ([#12436](https://github.com/PyTorchLightning/pytorch-lightning/pull/12436))
604+
605+
603606
### Removed
604607

605608
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
128128

129129
if trainer.strategy.root_device.type != "cuda":
130130
raise MisconfigurationException(
131-
"You are using GPUStatsMonitor but are not running on GPU"
132-
f" since gpus attribute in Trainer is set to {trainer.gpus}."
131+
"You are using GPUStatsMonitor but are not running on GPU."
132+
f" The root device type is {trainer.strategy.root_device.type}."
133133
)
134134

135135
# The logical device IDs for selected devices

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,6 +2195,10 @@ def scaler(self) -> Optional[Any]:
21952195

21962196
@property
def gpus(self) -> Optional[Union[List[int], str, int]]:
    """Return the raw ``gpus`` argument as given to the ``Trainer`` constructor (deprecated).

    .. deprecated:: v1.6
        This property will be removed in v1.8. Use :attr:`Trainer.num_devices` or
        :attr:`Trainer.device_ids` instead.
    """
    # Warn on every access; the helper's name suggests the warning is emitted only on the
    # rank-zero process — TODO(review): confirm `rank_zero_deprecation` semantics.
    rank_zero_deprecation(
        "`Trainer.gpus` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
    )
    # The accelerator connector stores the original constructor argument unchanged
    # (int, str, or list of device indices).
    return self._accelerator_connector.gpus
21992203

22002204
@property

tests/accelerators/test_accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def test_accelerator_gpu_with_devices(devices, plugin):
554554
def test_accelerator_auto_with_devices_gpu():
    """With ``accelerator="auto"`` and a single device, the GPU accelerator is selected."""
    auto_trainer = Trainer(accelerator="auto", devices=1)
    assert isinstance(auto_trainer.accelerator, GPUAccelerator)
    assert auto_trainer.num_devices == 1
558558

559559

560560
def test_validate_accelerator_and_devices():
@@ -946,8 +946,8 @@ def test_devices_auto_choice_cpu(
946946
@mock.patch("torch.cuda.device_count", return_value=2)
def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
    # NOTE(review): a second `mock.patch` (for `torch.cuda.is_available`, supplying the
    # `is_gpu_available_mock` argument) presumably sits above this decorator — it is not
    # visible in this hunk; confirm against the full file.
    # With CUDA reported available and two devices mocked, `accelerator="auto"` together
    # with `devices="auto"` should select the GPU accelerator and use both devices.
    trainer = Trainer(accelerator="auto", devices="auto")
    assert isinstance(trainer.accelerator, GPUAccelerator)
    assert trainer.num_devices == 2
951951

952952

953953
@pytest.mark.parametrize(

tests/deprecated_api/test_remove_1-8.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,3 +1106,23 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
11061106

11071107
trainer.callbacks = [TestCallbackSaveHookOverride()]
11081108
trainer.save_checkpoint(tmpdir + "/pathok.ckpt")
1109+
1110+
1111+
@pytest.mark.parametrize(
    "trainer_kwargs",
    [
        {"accelerator": "gpu", "devices": 2},
        {"accelerator": "gpu", "devices": [0, 2]},
        {"accelerator": "gpu", "devices": "2"},
        {"accelerator": "gpu", "devices": "0,"},
    ],
)
def test_trainer_gpus(monkeypatch, trainer_kwargs):
    """Accessing ``Trainer.gpus`` raises the v1.6 deprecation warning yet still returns
    the raw ``devices`` argument, for every accepted ``devices`` spelling."""
    # Pretend four CUDA devices exist so every parametrization above is constructible.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
    trainer = Trainer(**trainer_kwargs)
    expected_warning = (
        "`Trainer.gpus` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
    )
    with pytest.deprecated_call(match=expected_warning):
        assert trainer.gpus == trainer_kwargs["devices"]

tests/models/test_gpu.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import tests.helpers.pipelines as tpipes
2424
import tests.helpers.utils as tutils
2525
from pytorch_lightning import Trainer
26+
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
2627
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
2728
from pytorch_lightning.utilities import device_parser
2829
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,10 +193,16 @@ def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus
192193
sanitizing the gpus as only one of the GPUs is visible."""
193194
trainer = Trainer(gpus=gpus)
194195
assert isinstance(trainer._accelerator_connector.cluster_environment, TorchElasticEnvironment)
195-
assert trainer.gpus == gpus
196196
# when use gpu
197197
if device_parser.parse_gpu_ids(gpus) is not None:
198+
assert isinstance(trainer.accelerator, GPUAccelerator)
199+
assert trainer.num_devices == len(gpus) if isinstance(gpus, list) else gpus
198200
assert trainer.device_ids == device_parser.parse_gpu_ids(gpus)
201+
# fall back to cpu
202+
else:
203+
assert isinstance(trainer.accelerator, CPUAccelerator)
204+
assert trainer.num_devices == 1
205+
assert trainer.device_ids == [0]
199206

200207

201208
@RunIf(min_gpus=1)

0 commit comments

Comments
 (0)