Commit 82df087

Generalize internal checks for Accelerator in Trainer
1 parent aeb0b55 commit 82df087

File tree

1 file changed: +8 -10 lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 10 deletions
@@ -27,7 +27,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -1650,33 +1650,31 @@ def __setup_profiler(self) -> None:
         self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
 
     def _log_device_info(self) -> None:
-        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}")
+        rank_zero_info(
+            f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
+        )
 
         num_tpu_cores = (
-            self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0
+            self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
         )
         rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
 
         num_ipus = self.ipus if self.ipus is not None else 0
         rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")
 
-        if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU:
+        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
             rank_zero_warn(
                 "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.",
                 category=PossibleUserWarning,
             )
 
-        if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU:
+        if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
             rank_zero_warn(
                 "TPU available but not used. Set the `tpu_cores` flag in your trainer"
                 " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
             )
 
-        if (
-            _IPU_AVAILABLE
-            and self._device_type != _AcceleratorType.IPU
-            and not isinstance(self.accelerator, IPUAccelerator)
-        ):
+        if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
             rank_zero_warn(
                 "IPU available but not used. Set the `ipus` flag in your trainer"
                 " `Trainer(ipus=8)` or script `--ipus=8`."

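For context on the pattern this commit adopts, here is a minimal sketch (not part of the diff) of why isinstance checks generalize better than the old self._device_type == _AcceleratorType.GPU comparisons: an isinstance/issubclass check also matches user-defined accelerators derived from the built-in classes, which the enum comparison could not see. The MyGPUAccelerator class below is hypothetical.

from pytorch_lightning.accelerators import GPUAccelerator

# Hypothetical custom accelerator deriving from the stock GPUAccelerator.
class MyGPUAccelerator(GPUAccelerator):
    """A GPU accelerator with project-specific extras."""

# The generalized checks in _log_device_info rely on subclass relationships,
# so a custom accelerator like this one is still reported (and warned about)
# as a GPU accelerator:
assert issubclass(MyGPUAccelerator, GPUAccelerator)

Note that the IPU branch already carried an isinstance(self.accelerator, IPUAccelerator) check in the removed code; the commit simplifies it and brings the GPU and TPU branches in line with the same pattern.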