
Commit 534aa17

Remove AcceleratorConnector.tpu_cores and Deprecate Trainer.tpu_cores (#12437)
1 parent 9cd6d0f commit 534aa17

File tree: 6 files changed, +52 −14 lines

CHANGELOG.md
Lines changed: 6 additions & 0 deletions

@@ -612,6 +612,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `Trainer.gpus` in favor of `Trainer.device_ids` or `Trainer.num_devices` ([#12436](https://github.com/PyTorchLightning/pytorch-lightning/pull/12436))
 
 
+- Deprecated `Trainer.tpu_cores` in favor of `Trainer.num_devices` ([#12437](https://github.com/PyTorchLightning/pytorch-lightning/pull/12437))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

@@ -823,6 +826,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `AcceleratorConnector.parallel_devices` property ([#12075](https://github.com/PyTorchLightning/pytorch-lightning/pull/12075))
 
 
+- Removed `AcceleratorConnector.tpu_cores` property ([#12437](https://github.com/PyTorchLightning/pytorch-lightning/pull/12437))
+
+
 ### Fixed
 
 - Fixed an issue where `ModelCheckpoint` could delete last checkpoint from the old directory when `dirpath` has changed during resumed training ([#12225](https://github.com/PyTorchLightning/pytorch-lightning/pull/12225))
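
Taken together, the two changelog entries describe a straightforward migration for downstream code. A minimal before/after sketch (illustrative only; it assumes a Trainer built through the standard `accelerator`/`devices` flags, which the tests below also use):

    from pytorch_lightning import Trainer

    trainer = Trainer(accelerator="tpu", devices=8)

    # Before: the deprecated property (warns as of v1.6, removed in v1.8).
    # cores = trainer.tpu_cores

    # After: the accelerator-agnostic replacement from this commit.
    cores = trainer.num_devices

Note that `Trainer.num_devices` reports the device count for whichever accelerator is active; on a non-TPU run the deprecated property returns 0 instead (see the new implementation in trainer.py below).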

pytorch_lightning/callbacks/xla_stats_monitor.py
Lines changed: 3 additions & 3 deletions

@@ -73,10 +73,10 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
         if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
-        if isinstance(trainer.accelerator, TPUAccelerator):
+        if not isinstance(trainer.accelerator, TPUAccelerator):
             raise MisconfigurationException(
-                "You are using XLAStatsMonitor but are not running on TPU"
-                f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."
+                "You are using XLAStatsMonitor but are not running on TPU."
+                f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
             )
 
         device = trainer.strategy.root_device
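
Besides dropping the `tpu_cores` reference, this hunk fixes an inverted guard: previously the callback raised when the accelerator *was* a `TPUAccelerator`; it now raises only when it is not, and the message reports the accelerator class. A sketch of the corrected behavior (illustrative; it mirrors the regression test added in test_remove_1-7.py below, and note the callback's constructor separately checks `_TPU_AVAILABLE`, which is why that test monkeypatches the flag before instantiating):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import XLAStatsMonitor

    trainer = Trainer(accelerator="cpu", callbacks=[XLAStatsMonitor()])
    # trainer.fit(model) -> MisconfigurationException:
    #   "You are using XLAStatsMonitor but are not running on TPU.
    #    The accelerator is set to CPUAccelerator."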

pytorch_lightning/trainer/trainer.py
Lines changed: 11 additions & 5 deletions

@@ -1812,9 +1812,7 @@ def _log_device_info(self) -> None:
             f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
         )
 
-        num_tpu_cores = (
-            self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
-        )
+        num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
         rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
 
         num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0

@@ -2069,7 +2067,11 @@ def num_nodes(self) -> int:
     @property
     def device_ids(self) -> List[int]:
         """List of device indexes per node."""
-        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        devices = (
+            self.strategy.parallel_devices
+            if isinstance(self.strategy, ParallelStrategy)
+            else [self.strategy.root_device]
+        )
         device_ids = []
         for idx, device in enumerate(devices):
             if isinstance(device, torch.device):

@@ -2101,7 +2103,11 @@ def root_gpu(self) -> Optional[int]:
 
     @property
     def tpu_cores(self) -> int:
-        return self._accelerator_connector.tpu_cores
+        rank_zero_deprecation(
+            "`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+            "Please use `Trainer.num_devices` instead."
+        )
+        return self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
 
     @property
     def ipus(self) -> int:
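
Two behavioral notes on the Trainer changes, with a sketch (illustrative only; it assumes `rank_zero_deprecation` surfaces as a standard Python warning, which is how Lightning's deprecation helper behaves):

    import warnings
    from pytorch_lightning import Trainer

    trainer = Trainer(accelerator="cpu", devices=1)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deprecated property: warns, and returns 0 off-TPU per the new code.
        assert trainer.tpu_cores == 0
    assert any("`Trainer.tpu_cores` is deprecated" in str(w.message) for w in caught)

    # `device_ids` now branches on ParallelStrategy explicitly instead of
    # probing with getattr; a single-device run yields a single index.
    assert len(trainer.device_ids) == 1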

tests/accelerators/test_tpu.py
Lines changed: 3 additions & 4 deletions

@@ -94,7 +94,6 @@ def test_accelerator_cpu_with_tpu_cores_flag():
 
 
 @RunIf(tpu=True)
-@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()

@@ -103,7 +102,6 @@ def test_accelerator_tpu(accelerator, devices):
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
     assert trainer.num_devices == 8
-    assert trainer.tpu_cores == 8
 
 
 @RunIf(tpu=True)

@@ -114,13 +112,14 @@ def test_accelerator_tpu_with_tpu_cores_priority():
     with pytest.warns(UserWarning, match="The flag `devices=1` will be ignored,"):
         trainer = Trainer(accelerator="tpu", devices=1, tpu_cores=tpu_cores)
 
-    assert trainer.tpu_cores == tpu_cores
+    assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert trainer.num_devices == tpu_cores
 
 
 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_set_devices_if_none_tpu():
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
+    assert isinstance(trainer.accelerator, TPUAccelerator)
     assert trainer.num_devices == 8
 
 

tests/deprecated_api/test_remove_1-7.py
Lines changed: 18 additions & 2 deletions

@@ -20,6 +20,7 @@
 import pytest
 import torch
 
+import pytorch_lightning
 from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

@@ -35,6 +36,7 @@
     TorchElasticEnvironment,
 )
 from pytorch_lightning.strategies import SingleDeviceStrategy
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import MNISTDataModule

@@ -393,8 +395,8 @@ def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
     _ = GPUStatsMonitor()
 
 
-@RunIf(tpu=True)
-def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
+def test_v1_7_0_deprecate_xla_stats_monitor(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.callbacks.xla_stats_monitor, "_TPU_AVAILABLE", True)
     with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
         _ = XLAStatsMonitor()
 

@@ -516,3 +518,17 @@ def post_dispatch(self, trainer):
 
     with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
         CustomPlugin(torch.device("cpu"))
+
+
+def test_xla_stats_monitor_tpu_not_used(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.callbacks.xla_stats_monitor, "_TPU_AVAILABLE", True)
+    with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
+        xla_stats = XLAStatsMonitor()
+
+    trainer = Trainer(accelerator="cpu", callbacks=[xla_stats])
+    model = BoringModel()
+    with pytest.raises(
+        MisconfigurationException,
+        match="You are using XLAStatsMonitor but are not running on TPU. The accelerator is set to CPUAccelerator.",
+    ):
+        trainer.fit(model)

tests/deprecated_api/test_remove_1-8.py
Lines changed: 11 additions & 0 deletions

@@ -1136,3 +1136,14 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):
         " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
     ):
         assert trainer.gpus == trainer_kwargs["devices"]
+
+
+def test_trainer_tpu_cores(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
+    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "parse_devices", lambda: 8)
+    trainer = Trainer(accelerator="TPU", devices=8)
+    with pytest.deprecated_call(
+        match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+        "Please use `Trainer.num_devices` instead."
+    ):
+        trainer.tpu_cores == 8
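
For downstream code that cannot migrate right away, the new warning can be filtered in the usual way (a sketch; it relies on Lightning's deprecation warnings subclassing `DeprecationWarning`, which the `pytest.deprecated_call` usage above also depends on):

    import warnings

    warnings.filterwarnings(
        "ignore",
        message=r"`Trainer\.tpu_cores` is deprecated in v1\.6",
        category=DeprecationWarning,  # Lightning's deprecation warnings subclass this
    )

Migrating to `Trainer.num_devices` remains the intended fix; the property itself is scheduled for removal in v1.8.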
