diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0126950f2697b..93f0bdb72f634 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -606,6 +606,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `Trainer.gpus` in favor of `Trainer.device_ids` or `Trainer.num_devices` ([#12436](https://github.com/PyTorchLightning/pytorch-lightning/pull/12436))

+- Deprecated `Trainer.tpu_cores` in favor of `Trainer.num_devices` ([#12437](https://github.com/PyTorchLightning/pytorch-lightning/pull/12437))
+
+
 ### Removed

 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
@@ -817,6 +820,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed `AcceleratorConnector.parallel_devices` property ([#12075](https://github.com/PyTorchLightning/pytorch-lightning/pull/12075))

+- Removed `AcceleratorConnector.tpu_cores` property ([#12437](https://github.com/PyTorchLightning/pytorch-lightning/pull/12437))
+
+
 ### Fixed

 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py
index ebc6ca9d72357..c7fe59a59d515 100644
--- a/pytorch_lightning/callbacks/xla_stats_monitor.py
+++ b/pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -73,10 +73,10 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
         if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

-        if isinstance(trainer.accelerator, TPUAccelerator):
+        if not isinstance(trainer.accelerator, TPUAccelerator):
             raise MisconfigurationException(
-                "You are using XLAStatsMonitor but are not running on TPU"
-                f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."
+                "You are using XLAStatsMonitor but are not running on TPU."
+                f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
             )

         device = trainer.strategy.root_device
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e5168dfef4c83..0cb5f59f14dc4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1811,9 +1811,7 @@ def _log_device_info(self) -> None:
             f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
         )

-        num_tpu_cores = (
-            self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
-        )
+        num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
         rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

         num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
@@ -2068,7 +2066,11 @@ def num_nodes(self) -> int:
     @property
     def device_ids(self) -> List[int]:
         """List of device indexes per node."""
-        devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
+        devices = (
+            self.strategy.parallel_devices
+            if isinstance(self.strategy, ParallelStrategy)
+            else [self.strategy.root_device]
+        )
         device_ids = []
         for idx, device in enumerate(devices):
             if isinstance(device, torch.device):
@@ -2100,7 +2102,11 @@ def root_gpu(self) -> Optional[int]:

     @property
     def tpu_cores(self) -> int:
-        return self._accelerator_connector.tpu_cores
+        rank_zero_deprecation(
+            "`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+            "Please use `Trainer.num_devices` instead."
+        )
+        return self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0

     @property
     def ipus(self) -> int:
diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py
index 7e522ffd170cd..c1cb1e2f369a9 100644
--- a/tests/accelerators/test_tpu.py
+++ b/tests/accelerators/test_tpu.py
@@ -94,7 +94,6 @@ def test_accelerator_cpu_with_tpu_cores_flag():


 @RunIf(tpu=True)
-@pl_multi_process_test
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
 def test_accelerator_tpu(accelerator, devices):
     assert TPUAccelerator.is_available()
@@ -103,7 +102,6 @@ def test_accelerator_tpu(accelerator, devices):
     assert isinstance(trainer.accelerator, TPUAccelerator)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
     assert trainer.num_devices == 8
-    assert trainer.tpu_cores == 8


 @RunIf(tpu=True)
@@ -114,13 +112,14 @@ def test_accelerator_tpu_with_tpu_cores_priority():
     with pytest.warns(UserWarning, match="The flag `devices=1` will be ignored,"):
         trainer = Trainer(accelerator="tpu", devices=1, tpu_cores=tpu_cores)

-    assert trainer.tpu_cores == tpu_cores
+    assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert trainer.num_devices == tpu_cores


 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_set_devices_if_none_tpu():
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
+    assert isinstance(trainer.accelerator, TPUAccelerator)
     assert trainer.num_devices == 8


diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 8cfaa843a9f10..548e45683c13f 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -20,6 +20,7 @@
 import pytest
 import torch

+import pytorch_lightning
 from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
@@ -35,6 +36,7 @@
     TorchElasticEnvironment,
 )
 from pytorch_lightning.strategies import SingleDeviceStrategy
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import MNISTDataModule
@@ -393,8 +395,8 @@ def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
         _ = GPUStatsMonitor()


-@RunIf(tpu=True)
-def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
+def test_v1_7_0_deprecate_xla_stats_monitor(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.callbacks.xla_stats_monitor, "_TPU_AVAILABLE", True)
     with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
         _ = XLAStatsMonitor()

@@ -516,3 +518,17 @@ def post_dispatch(self, trainer):

     with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
         CustomPlugin(torch.device("cpu"))
+
+
+def test_xla_stats_monitor_tpu_not_used(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.callbacks.xla_stats_monitor, "_TPU_AVAILABLE", True)
+    with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
+        xla_stats = XLAStatsMonitor()
+
+    trainer = Trainer(accelerator="cpu", callbacks=[xla_stats])
+    model = BoringModel()
+    with pytest.raises(
+        MisconfigurationException,
+        match="You are using XLAStatsMonitor but are not running on TPU. The accelerator is set to CPUAccelerator.",
+    ):
+        trainer.fit(model)
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index a7cabf1794a8f..ae2e0104ab3b2 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -1126,3 +1126,14 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):
         " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
     ):
         assert trainer.gpus == trainer_kwargs["devices"]
+
+
+def test_trainer_tpu_cores(monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda *_: True)
+    monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "parse_devices", lambda *_: 8)
+    trainer = Trainer(accelerator="TPU", devices=8)
+    with pytest.deprecated_call(
+        match="`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
+        "Please use `Trainer.num_devices` instead."
+    ):
+        assert trainer.tpu_cores == 8
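
A minimal migration sketch for the deprecation introduced by this patch; it is not part of the patch itself, and the `accelerator="tpu", devices=8` configuration is illustrative and assumes a TPU host is available:

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator

trainer = Trainer(accelerator="tpu", devices=8)  # illustrative: requires TPU hardware

# Deprecated in v1.6, removed in v1.8: accessing the property now emits a warning.
# num_tpu_cores = trainer.tpu_cores

# Replacement mirrors the updated `_log_device_info` logic above:
# check the accelerator type, then read the device count.
num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0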