Commit cda5dc6

Deprecate num_processes, gpus, tpu_cores, and ipus from the Trainer constructor (#11040)

1 parent: 3f0f277

22 files changed: +167 −89 lines
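For downstream code, the migration is mechanical: the hardware choice moves into `accelerator` and the device count (or selection) moves into `devices`. A minimal before/after sketch grounded in the mappings this commit warns about (the device counts are illustrative):

```python
from pytorch_lightning import Trainer

# Before: deprecated in v1.7, slated for removal in v2.0.
# Each accelerator had its own device-count argument.
trainer = Trainer(gpus=4)        # now emits a deprecation warning
trainer = Trainer(tpu_cores=8)   # now emits a deprecation warning

# After: one argument picks the hardware, one picks the devices.
trainer = Trainer(accelerator="gpu", devices=4)
trainer = Trainer(accelerator="tpu", devices=8)
```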

CHANGELOG.md (2 additions, 1 deletion)

```diff
@@ -53,7 +53,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `pytorch_lightning.loggers.base.LightningLoggerBase` in favor of `pytorch_lightning.loggers.logger.Logger`, and deprecated `pytorch_lightning.loggers.base` in favor of `pytorch_lightning.loggers.logger` ([#12014](https://github.com/PyTorchLightning/pytorch-lightning/pull/12014))

-
+
+- Deprecated `num_processes`, `gpus`, `tpu_cores`, and `ipus` from the `Trainer` constructor in favor of using the `accelerator` and `devices` arguments ([#11040](https://github.com/PyTorchLightning/pytorch-lightning/pull/11040))
```

docs/source/common/trainer.rst (9 additions, 0 deletions)

```diff
@@ -736,6 +736,9 @@ See Also:
 gpus
 ^^^^

+.. warning:: ``gpus=x`` has been deprecated in v1.7 and will be removed in v2.0.
+    Please use ``accelerator='gpu'`` and ``devices=x`` instead.
+
 .. raw:: html

     <video width="50%" max-width="400px" controls
@@ -1055,6 +1058,9 @@ Number of GPU nodes for distributed training.
 num_processes
 ^^^^^^^^^^^^^

+.. warning:: ``num_processes=x`` has been deprecated in v1.7 and will be removed in v2.0.
+    Please use ``accelerator='cpu'`` and ``devices=x`` instead.
+
 .. raw:: html

     <video width="50%" max-width="400px" controls
@@ -1457,6 +1463,9 @@ track_grad_norm
 tpu_cores
 ^^^^^^^^^

+.. warning:: ``tpu_cores=x`` has been deprecated in v1.7 and will be removed in v2.0.
+    Please use ``accelerator='tpu'`` and ``devices=x`` instead.
+
 .. raw:: html

     <video width="50%" max-width="400px" controls
```
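The deprecated `gpus` argument also accepted index lists and comma-separated strings; those value shapes carry over to `devices` unchanged once `accelerator="gpu"` is set. A short sketch (values are illustrative):

```python
from pytorch_lightning import Trainer

# Device selections that `gpus` used to take pass through `devices` as-is.
Trainer(accelerator="gpu", devices=[0, 1])  # was: Trainer(gpus=[0, 1])
Trainer(accelerator="gpu", devices="0,1")   # was: Trainer(gpus="0,1")
Trainer(accelerator="cpu", devices=2)       # was: Trainer(num_processes=2)
```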

docs/source/ecosystem/asr_nlp_tts.rst (6 additions, 3 deletions)

```diff
@@ -204,7 +204,8 @@ including the PyTorch Lightning Trainer, customizable from the command line.
 .. code-block:: bash

     python NeMo/examples/asr/speech_to_text.py --config-name=quartznet_15x5 \
-        trainer.gpus=4 \
+        trainer.accelerator=gpu \
+        trainer.devices=4 \
         trainer.max_epochs=128 \
         +trainer.precision=16 \
         model.train_ds.manifest_filepath=<PATH_TO_DATA>/librispeech-train-all.json \
@@ -433,7 +434,8 @@ Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trai
         model.head.num_fc_layers=2 \
         model.dataset.data_dir=/path/to/my/data \
         trainer.max_epochs=5 \
-        trainer.gpus=[0,1]
+        trainer.accelerator=gpu \
+        trainer.devices=[0,1]

 -----------

@@ -643,7 +645,8 @@ Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trai
 .. code-block:: bash

     python NeMo/examples/tts/glow_tts.py \
-        trainer.gpus=4 \
+        trainer.accelerator=gpu \
+        trainer.devices=4 \
         trainer.max_epochs=400 \
         ...
         train_dataset=/path/to/train/data \
```

docs/source/starter/lightning_lite.rst (6 additions, 0 deletions)

```diff
@@ -431,6 +431,9 @@ Configure the devices to run on. Can be of type:
 gpus
 ====

+.. warning:: ``gpus=x`` has been deprecated in v1.7 and will be removed in v2.0.
+    Please use ``accelerator='gpu'`` and ``devices=x`` instead.
+
 Shorthand for setting ``devices=X`` and ``accelerator="gpu"``.

 .. code-block:: python
@@ -445,6 +448,9 @@ Shorthand for setting ``devices=X`` and ``accelerator="gpu"``.
 tpu_cores
 =========

+.. warning:: ``tpu_cores=x`` has been deprecated in v1.7 and will be removed in v2.0.
+    Please use ``accelerator='tpu'`` and ``devices=x`` instead.
+
 Shorthand for ``devices=X`` and ``accelerator="tpu"``.

 .. code-block:: python
```
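`LightningLite` takes the same `accelerator` and `devices` pair, so the replacement reads identically there; a sketch assuming the `pytorch_lightning.lite.LightningLite` entry point of the 1.x releases:

```python
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        ...  # training loop goes here


# Instead of the deprecated shorthands Lite(gpus=2) or Lite(tpu_cores=8):
Lite(accelerator="gpu", devices=2).run()
Lite(accelerator="tpu", devices=8).run()
```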

pytorch_lightning/trainer/connectors/accelerator_connector.py (24 additions, 3 deletions)

```diff
@@ -414,7 +414,7 @@ def _check_device_config_and_set_final_flags(
         self._devices_flag = devices

         # TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
-        self._map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
+        self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
             devices, num_processes, gpus, ipus, tpu_cores
         )

@@ -424,15 +424,36 @@ def _check_device_config_and_set_final_flags(
                 " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping"
             )

-    def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
+    def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
         self,
         devices: Optional[Union[List[int], str, int]],
         num_processes: Optional[int],
         gpus: Optional[Union[List[int], str, int]],
         ipus: Optional[int],
         tpu_cores: Optional[Union[List[int], str, int]],
     ) -> None:
-        """Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores."""
+        """Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
+        `accelerator_flag`."""
+        if num_processes is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(num_processes={num_processes})` is deprecated in v1.7 and will be removed"
+                f" in v2.0. Please use `Trainer(accelerator='cpu', devices={num_processes})` instead."
+            )
+        if gpus is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
+                f" in v2.0. Please use `Trainer(accelerator='gpu', devices={gpus!r})` instead."
+            )
+        if tpu_cores is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
+                f" in v2.0. Please use `Trainer(accelerator='tpu', devices={tpu_cores!r})` instead."
+            )
+        if ipus is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(ipus={ipus})` is deprecated in v1.7 and will be removed"
+                f" in v2.0. Please use `Trainer(accelerator='ipu', devices={ipus})` instead."
+            )
         self._gpus: Optional[Union[List[int], str, int]] = gpus
         self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
         deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores
```
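`rank_zero_deprecation` routes through Python's standard warnings machinery, so projects that need a quiet transition can filter these messages while they migrate. A hypothetical helper, assuming `LightningDeprecationWarning` is importable from `pytorch_lightning.utilities.warnings` as in recent 1.x releases:

```python
import warnings

# Assumption: this import path matches the installed 1.x release.
from pytorch_lightning.utilities.warnings import LightningDeprecationWarning

# Silence only the Trainer constructor deprecations added by this commit;
# all other deprecation warnings keep firing.
warnings.filterwarnings(
    "ignore",
    message=r".*is deprecated in v1.7 and will be removed in v2.0.*",
    category=LightningDeprecationWarning,
)
```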

pytorch_lightning/trainer/trainer.py (20 additions, 4 deletions)

```diff
@@ -139,12 +139,12 @@ def __init__(
         gradient_clip_algorithm: Optional[str] = None,
         process_position: int = 0,
         num_nodes: int = 1,
-        num_processes: Optional[int] = None,
+        num_processes: Optional[int] = None,  # TODO: Remove in 2.0
         devices: Optional[Union[List[int], str, int]] = None,
-        gpus: Optional[Union[List[int], str, int]] = None,
+        gpus: Optional[Union[List[int], str, int]] = None,  # TODO: Remove in 2.0
         auto_select_gpus: bool = False,
-        tpu_cores: Optional[Union[List[int], str, int]] = None,
-        ipus: Optional[int] = None,
+        tpu_cores: Optional[Union[List[int], str, int]] = None,  # TODO: Remove in 2.0
+        ipus: Optional[int] = None,  # TODO: Remove in 2.0
         enable_progress_bar: bool = True,
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
@@ -275,6 +275,10 @@ def __init__(
             gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
                 Default: ``None``.

+                .. deprecated:: v1.7
+                    ``gpus`` has been deprecated in v1.7 and will be removed in v2.0.
+                    Please use ``accelerator='gpu'`` and ``devices=x`` instead.
+
             gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
                 gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
                 Default: ``None``.
@@ -351,6 +355,10 @@ def __init__(
             num_processes: Number of processes for distributed training with ``accelerator="cpu"``.
                 Default: ``1``.

+                .. deprecated:: v1.7
+                    ``num_processes`` has been deprecated in v1.7 and will be removed in v2.0.
+                    Please use ``accelerator='cpu'`` and ``devices=x`` instead.
+
             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                 Set it to `-1` to run all batches in all validation dataloaders.
                 Default: ``2``.
@@ -381,9 +389,17 @@ def __init__(
             tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on (1)
                 Default: ``None``.

+                .. deprecated:: v1.7
+                    ``tpu_cores`` has been deprecated in v1.7 and will be removed in v2.0.
+                    Please use ``accelerator='tpu'`` and ``devices=x`` instead.
+
             ipus: How many IPUs to train on.
                 Default: ``None``.

+                .. deprecated:: v1.7
+                    ``ipus`` has been deprecated in v1.7 and will be removed in v2.0.
+                    Please use ``accelerator='ipu'`` and ``devices=x`` instead.
+
             track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. If using
                 Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them.
                 Default: ``-1``.
```
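During the deprecation window the old arguments still configure the run; they only warn first. A sketch of observing both behaviors at once (the process count is illustrative):

```python
import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer = Trainer(num_processes=2)  # deprecated spelling still works

# The warning fired, and the value was mapped onto `devices` internally.
assert any("deprecated in v1.7" in str(w.message) for w in caught)
assert trainer.num_devices == 2
```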

tests/accelerators/test_accelerator_connector.py (0 additions, 1 deletion)

```diff
@@ -49,7 +49,6 @@
 from tests.helpers.runif import RunIf


-# TODO: please modify/sunset any test that has accelerator=ddp/ddp2/ddp_cpu/ddp_spawn @daniellepintz
 def test_accelerator_choice_cpu(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     assert isinstance(trainer.accelerator, CPUAccelerator)
```

tests/accelerators/test_ipu.py (7 additions, 9 deletions)

```diff
@@ -101,10 +101,7 @@ def test_epoch_end(self, outputs) -> None:
 @mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
 def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir):
     with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
-        Trainer(default_root_dir=tmpdir, ipus=1)
-
-    with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"):
-        Trainer(default_root_dir=tmpdir, ipus=1, accelerator="ipu")
+        Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)


 @RunIf(ipu=True)
@@ -398,7 +395,8 @@ def test_manual_poptorch_opts(tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
-        ipus=2,
+        accelerator="ipu",
+        devices=2,
         fast_dev_run=True,
         strategy=IPUStrategy(inference_opts=inference_opts, training_opts=training_opts),
     )
@@ -552,13 +550,13 @@ def test_precision_plugin(tmpdir):
 @RunIf(ipu=True)
 def test_accelerator_ipu():
-    trainer = Trainer(accelerator="ipu", ipus=1)
+    trainer = Trainer(accelerator="ipu", devices=1)
     assert isinstance(trainer.accelerator, IPUAccelerator)

     trainer = Trainer(accelerator="ipu")
     assert isinstance(trainer.accelerator, IPUAccelerator)

-    trainer = Trainer(accelerator="auto", ipus=8)
+    trainer = Trainer(accelerator="auto", devices=8)
     assert isinstance(trainer.accelerator, IPUAccelerator)
@@ -592,8 +590,8 @@ def test_accelerator_ipu_with_ipus_priority():
 @RunIf(ipu=True)
 def test_set_devices_if_none_ipu():
-
-    trainer = Trainer(accelerator="ipu", ipus=8)
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        trainer = Trainer(accelerator="ipu", ipus=8)
     assert trainer.num_devices == 8
```
tests/accelerators/test_tpu.py (2 additions, 1 deletion)

```diff
@@ -118,7 +118,8 @@ def test_accelerator_tpu_with_tpu_cores_priority():
 @RunIf(tpu=True)
 def test_set_devices_if_none_tpu():
-    trainer = Trainer(accelerator="tpu", tpu_cores=8)
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        trainer = Trainer(accelerator="tpu", tpu_cores=8)
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert trainer.num_devices == 8
```
tests/benchmarks/test_basic_parity.py (2 additions, 1 deletion)

```diff
@@ -159,7 +159,8 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
         enable_progress_bar=False,
         enable_model_summary=False,
         enable_checkpointing=False,
-        gpus=1 if device_type == "cuda" else 0,
+        accelerator="gpu" if device_type == "cuda" else "cpu",
+        devices=1,
         logger=False,
         replace_sampler_ddp=False,
     )
```

tests/benchmarks/test_sharded_parity.py (11 additions, 2 deletions)

```diff
@@ -137,15 +137,24 @@ def plugin_parity_test(
     ddp_model = model_cls()
     use_cuda = gpus > 0

-    trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_spawn")
+    trainer = Trainer(
+        fast_dev_run=True, max_epochs=1, accelerator="gpu", devices=gpus, precision=precision, strategy="ddp_spawn"
+    )

     max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)

     # Reset and train Custom DDP
     seed_everything(seed)
     custom_plugin_model = model_cls()

-    trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_sharded_spawn")
+    trainer = Trainer(
+        fast_dev_run=True,
+        max_epochs=1,
+        accelerator="gpu",
+        devices=gpus,
+        precision=precision,
+        strategy="ddp_sharded_spawn",
+    )
     assert isinstance(trainer.strategy, DDPSpawnShardedStrategy)

     max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
```

tests/deprecated_api/test_remove_2-0.py (30 additions, 1 deletion)

```diff
@@ -11,14 +11,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test deprecated functionality which will be removed in v2.0."""
+"""Test deprecated functionality which will be removed in v2.0.0."""
+from unittest import mock
+
 import pytest

+import pytorch_lightning
 from pytorch_lightning import Trainer
 from tests.callbacks.test_callbacks import OldStatefulCallback
 from tests.helpers import BoringModel


+def test_v2_0_0_deprecated_num_processes():
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        _ = Trainer(num_processes=2)
+
+
+@mock.patch("torch.cuda.is_available", return_value=True)
+@mock.patch("torch.cuda.device_count", return_value=2)
+def test_v2_0_0_deprecated_gpus(*_):
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        _ = Trainer(gpus=0)
+
+
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
+def test_v2_0_0_deprecated_tpu_cores(*_):
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        _ = Trainer(tpu_cores=8)
+
+
+@mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
+def test_v2_0_0_deprecated_ipus(_, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", True)
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        _ = Trainer(ipus=4)
+
+
 def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     # test resume_from_checkpoint still works until v2.0 deprecation
     model = BoringModel()
```

tests/models/test_gpu.py (2 additions, 1 deletion)

```diff
@@ -191,7 +191,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun
 def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus):
     """Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit
     sanitizing the gpus as only one of the GPUs is visible."""
-    trainer = Trainer(gpus=gpus)
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        trainer = Trainer(gpus=gpus)
     assert isinstance(trainer._accelerator_connector.cluster_environment, TorchElasticEnvironment)
     # when use gpu
     if device_parser.parse_gpu_ids(gpus) is not None:
```

tests/models/test_hooks.py (7 additions, 4 deletions)

```diff
@@ -438,10 +438,13 @@ def _predict_batch(trainer, model, batches):
     [
         {},
         # these precision plugins modify the optimization flow, so testing them explicitly
-        pytest.param(dict(gpus=1, precision=16, amp_backend="native"), marks=RunIf(min_gpus=1)),
-        pytest.param(dict(gpus=1, precision=16, amp_backend="apex"), marks=RunIf(min_gpus=1, amp_apex=True)),
+        pytest.param(dict(accelerator="gpu", devices=1, precision=16, amp_backend="native"), marks=RunIf(min_gpus=1)),
         pytest.param(
-            dict(gpus=1, precision=16, strategy="deepspeed"), marks=RunIf(min_gpus=1, standalone=True, deepspeed=True)
+            dict(accelerator="gpu", devices=1, precision=16, amp_backend="apex"), marks=RunIf(min_gpus=1, amp_apex=True)
+        ),
+        pytest.param(
+            dict(accelerator="gpu", devices=1, precision=16, strategy="deepspeed"),
+            marks=RunIf(min_gpus=1, standalone=True, deepspeed=True),
         ),
     ],
 )
@@ -496,7 +499,7 @@ def training_step(self, batch, batch_idx):
     }
     if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex":
         saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
-    device = torch.device("cuda:0" if "gpus" in kwargs else "cpu")
+    device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu")
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
```

tests/models/test_tpu.py (2 additions, 1 deletion)

```diff
@@ -305,7 +305,8 @@ def test_tpu_cores_with_argparse(cli_args, expected):
     for k, v in expected.items():
         assert getattr(args, k) == v
-    assert Trainer.from_argparse_args(args)
+    with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
+        assert Trainer.from_argparse_args(args)


 @RunIf(tpu=True)
```
