Skip to content

Commit 6dc7542

Browse files
committed
Revert "Remove deprecated device attributes from Trainer (#14829)"
This reverts commit 1d3e971.
1 parent d15bd15 commit 6dc7542

File tree

3 files changed

+349
-12
lines changed

3 files changed

+349
-12
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
257257

258258
- Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))
259259

260-
261-
- Removed the deprecated `LightningIPUModule` ([#14830](https://github.com/Lightning-AI/lightning/pull/14830))
262-
263-
264-
- Removed the deprecated `Logger.agg_and_log_metrics` hook in favour of `Logger.log_metrics` and the `agg_key_funcs` and `agg_default_func` arguments. ([#14840](https://github.com/Lightning-AI/lightning/pull/14840))
265-
266-
267-
- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833))
268-
269-
270260
- Removed the deprecated `Trainer.root_gpu` attribute in favor of `Trainer.strategy.root_device` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))
271261

272262

@@ -275,6 +265,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
275265
- Removed the deprecated callback hooks `Callback.on_init_start` and `Callback.on_init_end` ([#14867](https://github.com/Lightning-AI/lightning/pull/14867))
276266

277267

268+
### Fixed
269+
278270
- Removed the deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#14870](https://github.com/Lightning-AI/lightning/pull/14870))
279271

280272

src/pytorch_lightning/trainer/trainer.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,46 @@ def num_devices(self) -> int:
17101710
"""Number of devices the trainer uses per node."""
17111711
return len(self.device_ids)
17121712

1713+
@property
1714+
def root_gpu(self) -> Optional[int]:
1715+
rank_zero_deprecation(
1716+
"`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
1717+
"Please use `Trainer.strategy.root_device.index` instead."
1718+
)
1719+
return self.strategy.root_device.index if isinstance(self.accelerator, CUDAAccelerator) else None
1720+
1721+
@property
1722+
def tpu_cores(self) -> int:
1723+
rank_zero_deprecation(
1724+
"`Trainer.tpu_cores` is deprecated in v1.6 and will be removed in v1.8. "
1725+
"Please use `Trainer.num_devices` instead."
1726+
)
1727+
return self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
1728+
1729+
@property
1730+
def ipus(self) -> int:
1731+
rank_zero_deprecation(
1732+
"`Trainer.ipus` was deprecated in v1.6 and will be removed in v1.8."
1733+
" Please use `Trainer.num_devices` instead."
1734+
)
1735+
return self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
1736+
1737+
@property
1738+
def num_gpus(self) -> int:
1739+
rank_zero_deprecation(
1740+
"`Trainer.num_gpus` was deprecated in v1.6 and will be removed in v1.8."
1741+
" Please use `Trainer.num_devices` instead."
1742+
)
1743+
return self.num_devices if isinstance(self.accelerator, CUDAAccelerator) else 0
1744+
1745+
@property
1746+
def devices(self) -> int:
1747+
rank_zero_deprecation(
1748+
"`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
1749+
" Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
1750+
)
1751+
return self.num_devices
1752+
17131753
@property
17141754
def lightning_module(self) -> "pl.LightningModule":
17151755
# TODO: this is actually an optional return
@@ -1752,7 +1792,15 @@ def scaler(self) -> Optional[Any]:
17521792
return getattr(self.precision_plugin, "scaler", None)
17531793

17541794
@property
1755-
def model(self) -> Optional[torch.nn.Module]:
1795+
def gpus(self) -> Optional[Union[List[int], str, int]]:
1796+
rank_zero_deprecation(
1797+
"`Trainer.gpus` was deprecated in v1.6 and will be removed in v1.8."
1798+
" Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
1799+
)
1800+
return self._accelerator_connector._gpus
1801+
1802+
@property
1803+
def model(self) -> torch.nn.Module:
17561804
"""The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
17571805
17581806
To access the pure LightningModule, use

0 commit comments

Comments
 (0)