
Commit 3861805

remove device_type

1 parent 2c05858 commit 3861805

File tree: 9 files changed, +23 additions, -27 deletions
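The change is mechanical across all nine files: wherever code compared `trainer._device_type` against an `_AcceleratorType` member, it now checks the type of the resolved accelerator instance. A minimal sketch of the new-style guard, using a hypothetical GPU-only callback (the class name is illustrative and not part of this commit):

    import pytorch_lightning as pl
    from pytorch_lightning.accelerators import GPUAccelerator
    from pytorch_lightning.callbacks.base import Callback
    from pytorch_lightning.utilities.exceptions import MisconfigurationException


    class MyGPUOnlyCallback(Callback):
        """Hypothetical callback mirroring the guard GPUStatsMonitor uses after this commit."""

        def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage=None) -> None:
            # Old style (removed by this commit):
            #     if trainer._device_type != _AcceleratorType.GPU: ...
            # New style: inspect which Accelerator the Trainer actually resolved to.
            if not isinstance(trainer.accelerator, GPUAccelerator):
                raise MisconfigurationException("MyGPUOnlyCallback requires a GPU run.")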

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 3 additions & 2 deletions
@@ -28,8 +28,9 @@
 import torch
 
 import pytorch_lightning as pl
+from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -126,7 +127,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         if not trainer.logger:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")
 
-        if trainer._device_type != _AcceleratorType.GPU:
+        if not isinstance(trainer.accelerator, GPUAccelerator):
             raise MisconfigurationException(
                 "You are using GPUStatsMonitor but are not running on GPU"
                 f" since gpus attribute in Trainer is set to {trainer.gpus}."

pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 3 additions & 2 deletions
@@ -20,8 +20,9 @@
 """
 import time
 
+from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _TPU_AVAILABLE:
@@ -70,7 +71,7 @@ def on_train_start(self, trainer, pl_module) -> None:
         if not trainer.logger:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
-        if trainer._device_type != _AcceleratorType.TPU:
+        if not isinstance(trainer.accelerator, TPUAccelerator):
             raise MisconfigurationException(
                 "You are using XLAStatsMonitor but are not running on TPU"
                 f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
+from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
@@ -30,7 +31,7 @@
 )
 from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import _AcceleratorType, AMPType
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -374,7 +375,7 @@ def _optimizer_step(
             optimizer,
             opt_idx,
             train_step_and_backward_closure,
-            on_tpu=(self.trainer._device_type == _AcceleratorType.TPU and _TPU_AVAILABLE),
+            on_tpu=(isinstance(self.trainer.accelerator, TPUAccelerator) and _TPU_AVAILABLE),
             using_native_amp=(self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE),
             using_lbfgs=is_lbfgs,
         )

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 3 additions & 2 deletions
@@ -17,11 +17,12 @@
 import torch
 
 import pytorch_lightning as pl
+from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
-from pytorch_lightning.utilities import _AcceleratorType, memory
+from pytorch_lightning.utilities import memory
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.warnings import rank_zero_deprecation
@@ -330,7 +331,7 @@ def gpus_metrics(self) -> Dict[str, float]:
         .. deprecated:: v1.5
             Will be removed in v1.7.
         """
-        if self.trainer._device_type == _AcceleratorType.GPU and self.log_gpu_memory:
+        if isinstance(self.trainer.accelerator, GPUAccelerator) and self.log_gpu_memory:
             mem_map = memory.get_memory_profile(self.log_gpu_memory)
             self._gpus_metrics.update(mem_map)
         return self._gpus_metrics

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 4 deletions
@@ -73,7 +73,6 @@
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import (
-    _AcceleratorType,
     _IPU_AVAILABLE,
     _StrategyType,
     _TPU_AVAILABLE,
@@ -1706,9 +1705,9 @@ def should_rank_save_checkpoint(self) -> bool:
     def _distrib_type(self) -> _StrategyType:
         return self._accelerator_connector._distrib_type
 
-    @property
-    def _device_type(self) -> _AcceleratorType:
-        return self._accelerator_connector._device_type
+    # @property
+    # def _device_type(self) -> _AcceleratorType:
+    #     return self._accelerator_connector._device_type
 
     @property
     def num_nodes(self) -> int:
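With the `_device_type` property disabled on `Trainer` (commented out above), code outside the library can apply the same substitution. A small migration sketch, assuming a GPU-enabled environment for the `gpus=1` example:

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

    trainer = Trainer(gpus=1)

    # Before: trainer._device_type == _AcceleratorType.GPU
    # After:  inspect the Accelerator instance the Trainer resolved to.
    on_gpu = isinstance(trainer.accelerator, GPUAccelerator)
    on_tpu = isinstance(trainer.accelerator, TPUAccelerator)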

pytorch_lightning/utilities/model_summary.py

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,8 @@
 from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
-from pytorch_lightning.utilities import _AcceleratorType, AMPType
+from pytorch_lightning.accelerators import TPUAccelerator
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -264,7 +265,7 @@ def _forward_example_input(self) -> None:
         if (
             trainer is not None
             and trainer.amp_backend == AMPType.NATIVE
-            and trainer._device_type != _AcceleratorType.TPU
+            and not isinstance(trainer.accelerator, TPUAccelerator)
         ):
             model.forward = torch.cuda.amp.autocast()(model.forward)

tests/accelerators/test_ipu.py

Lines changed: 3 additions & 7 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _AcceleratorType, _IPU_AVAILABLE
+from pytorch_lightning.utilities import _IPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -500,7 +500,6 @@ def test_accelerator_ipu():
 
     trainer = Trainer(accelerator="ipu", ipus=1)
 
-    assert trainer._device_type == "ipu"
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
     with pytest.raises(
@@ -510,7 +509,6 @@ def test_accelerator_ipu():
 
     trainer = Trainer(accelerator="auto", ipus=8)
 
-    assert trainer._device_type == "ipu"
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
 
@@ -519,7 +517,6 @@ def test_accelerator_cpu_with_ipus_flag():
 
     trainer = Trainer(accelerator="cpu", ipus=1)
 
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 
@@ -538,7 +535,7 @@ def test_accelerator_auto_with_devices_ipu():
 
     trainer = Trainer(accelerator="auto", devices=8)
 
-    assert trainer._device_type == "ipu"
+    assert isinstance(trainer.accelerator, IPUAccelerator)
    assert trainer.ipus == 8
 
 
@@ -567,11 +564,10 @@ def test_strategy_choice_ipu_plugin(tmpdir):
 
 
 @RunIf(ipu=True)
-def test_device_type_when_training_plugin_ipu_passed(tmpdir):
+def test_accelerator_type_when_training_plugin_ipu_passed(tmpdir):
 
     trainer = Trainer(strategy=IPUPlugin(), ipus=8)
     assert isinstance(trainer.training_type_plugin, IPUPlugin)
-    assert trainer._device_type == _AcceleratorType.IPU
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
 

tests/accelerators/test_tpu.py

Lines changed: 1 addition & 4 deletions
@@ -85,7 +85,6 @@ def test_accelerator_tpu():
 
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
 
-    assert trainer._device_type == "tpu"
     assert isinstance(trainer.accelerator, TPUAccelerator)
 
     with pytest.raises(
@@ -99,7 +98,6 @@ def test_accelerator_cpu_with_tpu_cores_flag():
 
     trainer = Trainer(accelerator="cpu", tpu_cores=8)
 
-    assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)
 
 
@@ -108,7 +106,6 @@ def test_accelerator_tpu_with_auto():
 
     trainer = Trainer(accelerator="auto", tpu_cores=8)
 
-    assert trainer._device_type == "tpu"
     assert isinstance(trainer.accelerator, TPUAccelerator)
 
 
@@ -127,7 +124,7 @@ def test_accelerator_auto_with_devices_tpu():
 
     trainer = Trainer(accelerator="auto", devices=8)
 
-    assert trainer._device_type == "tpu"
+    assert isinstance(trainer.accelerator, TPUAccelerator)
     assert trainer.tpu_cores == 8
 
 

tests/models/test_tpu.py

Lines changed: 1 addition & 2 deletions
@@ -26,7 +26,7 @@
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync
-from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
@@ -474,5 +474,4 @@ def test_device_type_when_training_plugin_tpu_passed(tmpdir):
 
     trainer = Trainer(strategy=TPUSpawnPlugin(), tpu_cores=8)
     assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
-    assert trainer._device_type == _AcceleratorType.TPU
     assert isinstance(trainer.accelerator, TPUAccelerator)
