
Commit e00a14b

Remove Trainer._device_type
1 parent 2c75a7b commit e00a14b

File tree

10 files changed: 100 additions, 147 deletions

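The pattern applied throughout this commit: call sites that compared the removed `Trainer._device_type` property against `_AcceleratorType` members now check the type of `trainer.accelerator` directly. A minimal before/after sketch of that migration (the single-GPU configuration below is illustrative, not taken from the diff):

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

trainer = Trainer(gpus=1)  # illustrative configuration

# Before: trainer._device_type == _AcceleratorType.GPU
# After:
if isinstance(trainer.accelerator, GPUAccelerator):
    print("training will run on GPU")

# Before: trainer._device_type != _AcceleratorType.TPU
# After:
if not isinstance(trainer.accelerator, TPUAccelerator):
    print("training will not run on TPU")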

pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 3 additions & 2 deletions
@@ -21,8 +21,9 @@
 import time
 
 import pytorch_lightning as pl
+from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

@@ -72,7 +73,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
         if not trainer.logger:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
-        if trainer._device_type != _AcceleratorType.TPU:
+        if not isinstance(trainer.accelerator, TPUAccelerator):
             raise MisconfigurationException(
                 "You are using XLAStatsMonitor but are not running on TPU"
                 f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."

pytorch_lightning/core/lightning.py

Lines changed: 0 additions & 2 deletions
@@ -95,8 +95,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # pointer to the trainer object
         self.trainer = None
 
-        self._device_type = None
-
         # true if using amp
         self.use_amp: bool = False
 

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 3 additions & 3 deletions
@@ -19,6 +19,7 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
+from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult

@@ -29,10 +30,9 @@
     check_finite_loss,
 )
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import _AcceleratorType, AMPType
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
-from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -369,7 +369,7 @@ def _optimizer_step(
             optimizer,
             opt_idx,
             train_step_and_backward_closure,
-            on_tpu=(self.trainer._device_type == _AcceleratorType.TPU and _TPU_AVAILABLE),
+            on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
             using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
             using_lbfgs=is_lbfgs,
         )
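
The `on_tpu` value rewritten above is forwarded to the user-facing `LightningModule.optimizer_step` hook. A hedged sketch of an override in this 1.6-era hook signature, where the flag now reflects the accelerator type rather than `_device_type`; the warm-up logic is purely illustrative:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,            # now derived from isinstance(trainer.accelerator, TPUAccelerator)
        using_native_amp=False,
        using_lbfgs=False,
    ):
        # illustrative: linearly warm up the learning rate over the first 500 steps
        if self.trainer.global_step < 500:
            scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = scale * 1e-3
        # run the optimizer with the closure provided by the loop
        optimizer.step(closure=optimizer_closure)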

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 3 additions & 2 deletions
@@ -16,11 +16,12 @@
 import torch
 
 import pytorch_lightning as pl
+from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import _AcceleratorType, memory
+from pytorch_lightning.utilities import memory
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -305,7 +306,7 @@ def gpus_metrics(self) -> Dict[str, float]:
         .. deprecated:: v1.5
             Will be removed in v1.7.
         """
-        if self.trainer._device_type == _AcceleratorType.GPU and self.log_gpu_memory:
+        if isinstance(self.trainer.accelerator, GPUAccelerator) and self.log_gpu_memory:
             mem_map = memory.get_memory_profile(self.log_gpu_memory)
             self._gpus_metrics.update(mem_map)
         return self._gpus_metrics

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 15 deletions
@@ -30,7 +30,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule

@@ -75,7 +75,6 @@
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import (
-    _AcceleratorType,
     _IPU_AVAILABLE,
     _TPU_AVAILABLE,
     AMPType,

@@ -1716,33 +1715,31 @@ def __setup_profiler(self) -> None:
         self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
 
     def _log_device_info(self) -> None:
-        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}")
+        rank_zero_info(
+            f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
+        )
 
         num_tpu_cores = (
-            self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0
+            self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
         )
         rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")
 
         num_ipus = self.ipus if self.ipus is not None else 0
         rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")
 
-        if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU:
+        if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
             rank_zero_warn(
                 "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.",
                 category=PossibleUserWarning,
             )
 
-        if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU:
+        if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
             rank_zero_warn(
                 "TPU available but not used. Set the `tpu_cores` flag in your trainer"
                 " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
             )
 
-        if (
-            _IPU_AVAILABLE
-            and self._device_type != _AcceleratorType.IPU
-            and not isinstance(self.accelerator, IPUAccelerator)
-        ):
+        if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
             rank_zero_warn(
                 "IPU available but not used. Set the `ipus` flag in your trainer"
                 " `Trainer(ipus=8)` or script `--ipus=8`."

@@ -1962,10 +1959,6 @@ def should_rank_save_checkpoint(self) -> bool:
             isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
         )
 
-    @property
-    def _device_type(self) -> _AcceleratorType:
-        return self._accelerator_connector.device_type
-
     @property
     def num_nodes(self) -> int:
         return self._accelerator_connector.num_nodes

tests/accelerators/test_accelerator_connector.py

Lines changed: 1 addition & 20 deletions
@@ -42,7 +42,6 @@
     ParallelStrategy,
     SingleDeviceStrategy,
 )
-from pytorch_lightning.utilities import _AcceleratorType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.runif import RunIf
 

@@ -447,10 +446,7 @@ def test_accelerator_choice_multi_node_gpu(
 
 @mock.patch("torch.cuda.is_available", return_value=False)
 def test_accelerator_cpu(_):
-
     trainer = Trainer(accelerator="cpu")
-
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):

@@ -464,36 +460,25 @@ def test_accelerator_cpu(_):
 
 @RunIf(min_gpus=1)
 def test_accelerator_gpu():
-
     trainer = Trainer(accelerator="gpu", gpus=1)
-
-    assert trainer._device_type == "gpu"
     assert isinstance(trainer.accelerator, GPUAccelerator)
 
     trainer = Trainer(accelerator="gpu")
     assert isinstance(trainer.accelerator, GPUAccelerator)
 
     trainer = Trainer(accelerator="auto", gpus=1)
-
-    assert trainer._device_type == "gpu"
     assert isinstance(trainer.accelerator, GPUAccelerator)
 
 
 @RunIf(min_gpus=1)
 def test_accelerator_cpu_with_gpus_flag():
-
     trainer = Trainer(accelerator="cpu", gpus=1)
-
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 
 @RunIf(min_gpus=2)
 def test_accelerator_cpu_with_multiple_gpus():
-
     trainer = Trainer(accelerator="cpu", gpus=2)
-
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 

@@ -532,10 +517,8 @@ def test_accelerator_gpu_with_devices(devices, plugin):
 
 @RunIf(min_gpus=1)
 def test_accelerator_auto_with_devices_gpu():
-
     trainer = Trainer(accelerator="auto", devices=1)
-
-    assert trainer._device_type == "gpu"
+    assert isinstance(trainer.accelerator, GPUAccelerator)
     assert trainer.gpus == 1
 
 

@@ -662,10 +645,8 @@ def test_strategy_choice_gpu_plugin(tmpdir, plugin):
 @RunIf(min_gpus=2)
 @pytest.mark.parametrize("plugin", [DDPSpawnStrategy, DDPStrategy])
 def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
-
     trainer = Trainer(strategy=plugin(), gpus=2)
     assert isinstance(trainer.strategy, plugin)
-    assert trainer._device_type == _AcceleratorType.GPU
     assert isinstance(trainer.accelerator, GPUAccelerator)
 

tests/accelerators/test_ipu.py

Lines changed: 2 additions & 14 deletions
@@ -25,7 +25,7 @@
 from pytorch_lightning.strategies.ipu import IPUStrategy
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _AcceleratorType, _IPU_AVAILABLE
+from pytorch_lightning.utilities import _IPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule

@@ -499,27 +499,19 @@ def test_precision_plugin(tmpdir):
 
 @RunIf(ipu=True)
 def test_accelerator_ipu():
-
     trainer = Trainer(accelerator="ipu", ipus=1)
-
-    assert trainer._device_type == "ipu"
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
     trainer = Trainer(accelerator="ipu")
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
     trainer = Trainer(accelerator="auto", ipus=8)
-
-    assert trainer._device_type == "ipu"
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
 
 @RunIf(ipu=True)
 def test_accelerator_cpu_with_ipus_flag():
-
     trainer = Trainer(accelerator="cpu", ipus=1)
-
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 

@@ -535,10 +527,8 @@ def test_accelerator_ipu_with_devices():
 
 @RunIf(ipu=True)
 def test_accelerator_auto_with_devices_ipu():
-
     trainer = Trainer(accelerator="auto", devices=8)
-
-    assert trainer._device_type == "ipu"
+    assert isinstance(trainer.accelerator, IPUAccelerator)
     assert trainer.ipus == 8
 
 

@@ -568,10 +558,8 @@ def test_strategy_choice_ipu_plugin(tmpdir):
 
 @RunIf(ipu=True)
 def test_device_type_when_training_plugin_ipu_passed(tmpdir):
-
     trainer = Trainer(strategy=IPUStrategy(), ipus=8)
     assert isinstance(trainer.strategy, IPUStrategy)
-    assert trainer._device_type == _AcceleratorType.IPU
     assert isinstance(trainer.accelerator, IPUAccelerator)
 

tests/accelerators/test_tpu.py

Lines changed: 11 additions & 33 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO
-from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
+from pytorch_lightning.strategies import DDPStrategy, SingleTPUStrategy, TPUSpawnStrategy
 from pytorch_lightning.utilities import find_shared_parameters
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset

@@ -86,48 +86,33 @@ def test_accelerator_tpu():
     assert TPUAccelerator.is_available()
 
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
-
-    assert trainer._device_type == "tpu"
     assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert isinstance(trainer.strategy, TPUSpawnStrategy)
 
     trainer = Trainer(accelerator="tpu")
     assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert isinstance(trainer.strategy, SingleTPUStrategy)
 
 
 @RunIf(tpu=True)
 def test_accelerator_cpu_with_tpu_cores_flag():
-
     trainer = Trainer(accelerator="cpu", tpu_cores=8)
-
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 
-@RunIf(tpu=True)
-def test_accelerator_tpu_with_auto():
-
-    trainer = Trainer(accelerator="auto", tpu_cores=8)
-
-    assert trainer._device_type == "tpu"
-    assert isinstance(trainer.accelerator, TPUAccelerator)
-
-
-@RunIf(tpu=True)
-def test_accelerator_tpu_with_devices():
-
-    trainer = Trainer(accelerator="tpu", devices=8)
-
-    assert trainer.tpu_cores == 8
-    assert isinstance(trainer.strategy, TPUSpawnStrategy)
-    assert isinstance(trainer.accelerator, TPUAccelerator)
-
-
 @RunIf(tpu=True)
 def test_accelerator_auto_with_devices_tpu():
+    assert TPUAccelerator.is_available()
 
     trainer = Trainer(accelerator="auto", devices=8)
+    assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert isinstance(trainer.strategy, TPUSpawnStrategy)
+    assert trainer.tpu_cores == 8
 
-    assert trainer._device_type == "tpu"
+    trainer = Trainer(accelerator="auto", devices="auto")
+    assert isinstance(trainer.accelerator, TPUAccelerator)
+    assert isinstance(trainer.strategy, TPUSpawnStrategy)
+    assert trainer.devices == 8
     assert trainer.tpu_cores == 8
 
 

@@ -328,10 +313,3 @@ def test_mp_device_dataloader_attribute(_):
     dataset = RandomDataset(32, 64)
     dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
     assert dataloader.dataset == dataset
-
-
-@RunIf(tpu=True)
-def test_devices_auto_choice_tpu():
-    trainer = Trainer(accelerator="auto", devices="auto")
-    assert trainer.devices == 8
-    assert trainer.tpu_cores == 8

tests/models/test_tpu.py

Lines changed: 1 addition & 3 deletions
@@ -26,7 +26,7 @@
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.strategies import TPUSpawnStrategy
 from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync
-from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset

@@ -469,8 +469,6 @@ def teardown(self, stage):
 @RunIf(tpu=True)
 @pl_multi_process_test
 def test_device_type_when_training_plugin_tpu_passed(tmpdir):
-
     trainer = Trainer(strategy=TPUSpawnStrategy(), tpu_cores=8)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)
-    assert trainer._device_type == _AcceleratorType.TPU
     assert isinstance(trainer.accelerator, TPUAccelerator)
