Commit 186929d

Generalize internal checks for Accelerator in Trainer
1 parent: a4083df

File tree

1 file changed: +8, -10 lines

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 10 deletions
@@ -27,7 +27,7 @@
 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -1626,33 +1626,31 @@ def __setup_profiler(self) -> None:
         self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

     def _log_device_info(self) -> None:
-        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}")
+        rank_zero_info(
+            f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
+        )

         num_tpu_cores = (
-            self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0
+            self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
         )
         rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

         num_ipus = self.ipus if self.ipus is not None else 0
         rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

-        if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU:
+        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
             rank_zero_warn(
                 "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.",
                 category=PossibleUserWarning,
             )

-        if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU:
+        if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
             rank_zero_warn(
                 "TPU available but not used. Set the `tpu_cores` flag in your trainer"
                 " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
             )

-        if (
-            _IPU_AVAILABLE
-            and self._device_type != _AcceleratorType.IPU
-            and not isinstance(self.accelerator, IPUAccelerator)
-        ):
+        if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
             rank_zero_warn(
                 "IPU available but not used. Set the `ipus` flag in your trainer"
                 " `Trainer(ipus=8)` or script `--ipus=8`."

Comments (0)