
Commit a52a6ea

Authored by kaushikb11, Borda, rohitgr7, carmocca, and pre-commit-ci[bot]
Add support for pluggable Accelerators (#12030)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent a9024ce commit a52a6ea

14 files changed: +165 / -71 lines

14 files changed

+165
-71
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `estimated_stepping_batches` property to `Trainer` ([#11599](https://github.com/PyTorchLightning/pytorch-lightning/pull/11599))
 
 
+- Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))
+
+
 ### Changed
 
 - Make `benchmark` flag optional and set its value based on the deterministic flag ([#11944](https://github.com/PyTorchLightning/pytorch-lightning/pull/11944))

docs/source/extensions/accelerator.rst

Lines changed: 3 additions & 2 deletions
@@ -20,6 +20,7 @@ Each Accelerator gets two plugins upon initialization:
 One to handle differences from the training routine and one to handle different precisions.
 
 .. testcode::
+    :skipif: torch.cuda.device_count() < 2
 
     from pytorch_lightning import Trainer
     from pytorch_lightning.accelerators import GPUAccelerator
@@ -28,8 +29,8 @@ One to handle differences from the training routine and one to handle different
 
     accelerator = GPUAccelerator()
     precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
-    training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
-    trainer = Trainer(strategy=training_type_plugin)
+    training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
+    trainer = Trainer(strategy=training_strategy, devices=2)
 
 
 We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
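For reference, the updated documentation example assembles into the snippet below. The import locations for NativeMixedPrecisionPlugin and DDPStrategy are not visible in this hunk, so the paths used here are assumptions, and the snippet presumes a machine with at least two CUDA devices, matching the new :skipif: guard.

# Assembled from the docs example above; the plugin/strategy import paths are assumed,
# and at least two visible CUDA devices are required (per the new `:skipif:` directive).
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy

accelerator = GPUAccelerator()
precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_strategy, devices=2)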

pytorch_lightning/accelerators/accelerator.py

Lines changed: 10 additions & 0 deletions
@@ -55,6 +55,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """
         raise NotImplementedError
 
+    @staticmethod
+    @abstractmethod
+    def parse_devices(devices: Any) -> Any:
+        """Accelerator device parsing logic."""
+
+    @staticmethod
+    @abstractmethod
+    def get_parallel_devices(devices: Any) -> Any:
+        """Gets parallel devices for the Accelerator."""
+
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:
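With parse_devices and get_parallel_devices now abstract, a third-party accelerator must supply both hooks. The sketch below is modeled on the Accel class added to tests/accelerators/test_accelerator_connector.py in this commit; the MyAccelerator name, the trivial bodies, and the extra is_available/get_device_stats overrides are illustrative assumptions, since the full set of abstract methods is not visible in this hunk.

# Minimal sketch of a pluggable accelerator (illustrative only, not part of this commit).
import torch
from pytorch_lightning.accelerators import Accelerator


class MyAccelerator(Accelerator):
    @staticmethod
    def parse_devices(devices):
        # interpret the raw `devices` flag passed to the Trainer
        return devices

    @staticmethod
    def get_parallel_devices(devices):
        # map the parsed flag to concrete device handles
        return [torch.device("cpu")] * devices

    @staticmethod
    def auto_device_count() -> int:
        # device count to use when `devices="auto"` is passed
        return 1

    @staticmethod
    def is_available() -> bool:
        return True

    def get_device_stats(self, device):
        return {}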

pytorch_lightning/accelerators/cpu.py

Lines changed: 18 additions & 4 deletions
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
-
-from typing import Any
+from typing import Any, Dict, List, Union
 
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.types import _DEVICE
 
 
@@ -35,10 +34,25 @@ def setup_environment(self, root_device: torch.device) -> None:
         if root_device.type != "cpu":
             raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")
 
-    def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
         """CPU device stats aren't supported yet."""
         return {}
 
+    @staticmethod
+    def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]:
+        """Accelerator device parsing logic."""
+        return devices
+
+    @staticmethod
+    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+        if isinstance(devices, int):
+            return [torch.device("cpu")] * devices
+        rank_zero_warn(
+            f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
+        )
+        return []
+
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
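In effect, the CPU accelerator now owns its own device resolution: parse_devices passes the flag through untouched, and get_parallel_devices replicates a single torch.device("cpu") when the flag is an integer, warning and returning an empty list otherwise. A small behavior sketch, assuming a plain CPU-only environment:

# Behavior sketch for the methods added above (assumed environment: CPU only).
import torch
from pytorch_lightning.accelerators import CPUAccelerator

assert CPUAccelerator.parse_devices(2) == 2                               # passed through as-is
assert CPUAccelerator.get_parallel_devices(2) == [torch.device("cpu")] * 2
assert CPUAccelerator.get_parallel_devices("2") == []                     # non-int flag: warns, returns []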

pytorch_lightning/accelerators/gpu.py

Lines changed: 15 additions & 6 deletions
@@ -11,18 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
-
 import logging
 import os
 import shutil
 import subprocess
-from typing import Any
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from pytorch_lightning.utilities.types import _DEVICE
@@ -44,7 +43,7 @@ def setup_environment(self, root_device: torch.device) -> None:
             raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
         torch.cuda.set_device(root_device)
 
-    def setup(self, trainer: pl.Trainer) -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         # TODO refactor input from trainer to local_rank @four4fish
         self.set_nvidia_flags(trainer.local_rank)
         # clear cache before training
@@ -58,7 +57,7 @@ def set_nvidia_flags(local_rank: int) -> None:
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
 
-    def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
         """Gets stats for the given GPU device.
 
         Args:
@@ -75,6 +74,16 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
             return torch.cuda.memory_stats(device)
         return get_nvidia_gpu_stats(device)
 
+    @staticmethod
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+        """Accelerator device parsing logic."""
+        return device_parser.parse_gpu_ids(devices)
+
+    @staticmethod
+    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+        return [torch.device("cuda", i) for i in devices]
+
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
@@ -85,7 +94,7 @@ def is_available() -> bool:
         return torch.cuda.device_count() > 0
 
 
-def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
+def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
 
     Args:
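The GPU accelerator delegates to the existing device_parser helpers: parse_devices normalizes the flag into a list of GPU indices, and get_parallel_devices turns those indices into torch.device("cuda", i) handles. A behavior sketch, assuming a machine where CUDA devices 0 and 1 are visible:

# Behavior sketch for the methods added above (assumes two visible CUDA devices).
import torch
from pytorch_lightning.accelerators import GPUAccelerator

gpu_ids = GPUAccelerator.parse_devices(2)   # -> [0, 1] via device_parser.parse_gpu_ids
assert GPUAccelerator.get_parallel_devices(gpu_ids) == [torch.device("cuda", 0), torch.device("cuda", 1)]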

pytorch_lightning/accelerators/ipu.py

Lines changed: 11 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 
 import torch
 
@@ -26,6 +26,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """IPU device stats aren't supported yet."""
         return {}
 
+    @staticmethod
+    def parse_devices(devices: int) -> int:
+        """Accelerator device parsing logic."""
+        return devices
+
+    @staticmethod
+    def get_parallel_devices(devices: int) -> List[int]:
+        """Gets parallel devices for the Accelerator."""
+        return list(range(devices))
+
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
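The IPU accelerator keeps the integer flag as-is and expands it into consecutive device indices. A small behavior sketch of the methods added above:

# Behavior sketch for the methods added above.
from pytorch_lightning.accelerators import IPUAccelerator

assert IPUAccelerator.parse_devices(4) == 4
assert IPUAccelerator.get_parallel_devices(4) == [0, 1, 2, 3]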

pytorch_lightning/accelerators/tpu.py

Lines changed: 14 additions & 1 deletion
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -43,6 +44,18 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         }
         return device_stats
 
+    @staticmethod
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
+        """Accelerator device parsing logic."""
+        return device_parser.parse_tpu_cores(devices)
+
+    @staticmethod
+    def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
+        """Gets parallel devices for the Accelerator."""
+        if isinstance(devices, int):
+            return list(range(devices))
+        return devices
+
     @staticmethod
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
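The TPU accelerator routes parsing through device_parser.parse_tpu_cores, which accepts an integer core count, a string, or a single-element list selecting one core; get_parallel_devices then expands an integer into a range and passes a list through. A behavior sketch, assuming the values below pass TPU-core validation:

# Behavior sketch for the methods added above (assumes these values pass TPU-core validation).
from pytorch_lightning.accelerators import TPUAccelerator

assert TPUAccelerator.parse_devices(8) == 8
assert TPUAccelerator.get_parallel_devices(8) == list(range(8))
assert TPUAccelerator.get_parallel_devices([1]) == [1]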

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 24 additions & 41 deletions
@@ -168,6 +168,7 @@ def __init__(
         self._precision_flag: Optional[Union[int, str]] = None
         self._precision_plugin_flag: Optional[PrecisionPlugin] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
+        self._parallel_devices: List[Union[int, torch.device]] = []
         self.checkpoint_io: Optional[CheckpointIO] = None
         self._amp_type_flag: Optional[LightningEnum] = None
         self._amp_level_flag: Optional[str] = amp_level
@@ -361,6 +362,7 @@ def _check_config_and_set_final_flags(
                 self._accelerator_flag = "cpu"
             if self._strategy_flag.parallel_devices[0].type == "cuda":
                 self._accelerator_flag = "gpu"
+            self._parallel_devices = self._strategy_flag.parallel_devices
 
         amp_type = amp_type if isinstance(amp_type, str) else None
         self._amp_type_flag = AMPType.from_str(amp_type)
@@ -387,7 +389,7 @@ def _check_device_config_and_set_final_flags(
            devices, num_processes, gpus, ipus, tpu_cores
        )
 
-        if self._devices_flag in ([], 0, "0", "0,"):
+        if self._devices_flag in ([], 0, "0"):
            rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator")
            self._accelerator_flag = "cpu"
 
@@ -408,10 +410,8 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
        """Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores."""
        self._gpus: Optional[Union[List[int], str, int]] = gpus
        self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
-        gpus = device_parser.parse_gpu_ids(gpus)
-        tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores
-        if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in (0, "0"):
+        if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
            if devices:
                # TODO: @awaelchli improve error message
                rank_zero_warn(
@@ -456,51 +456,34 @@ def _choose_accelerator(self) -> str:
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
         # TODO add device availability check
-        self._parallel_devices: List[Union[int, torch.device]] = []
-
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
-        elif self._accelerator_flag == "tpu":
-            self.accelerator = TPUAccelerator()
-            self._set_devices_flag_if_auto_passed()
-            if isinstance(self._devices_flag, int):
-                self._parallel_devices = list(range(self._devices_flag))
-            else:
-                self._parallel_devices = self._devices_flag  # type: ignore[assignment]
-
-        elif self._accelerator_flag == "ipu":
-            self.accelerator = IPUAccelerator()
-            self._set_devices_flag_if_auto_passed()
-            if isinstance(self._devices_flag, int):
-                self._parallel_devices = list(range(self._devices_flag))
-
-        elif self._accelerator_flag == "gpu":
-            self.accelerator = GPUAccelerator()
-            self._set_devices_flag_if_auto_passed()
-            if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str):
-                self._devices_flag = int(self._devices_flag)
-                self._parallel_devices = (
-                    [torch.device("cuda", i) for i in device_parser.parse_gpu_ids(self._devices_flag)]  # type: ignore
-                    if self._devices_flag != 0
-                    else []
+        else:
+            ACCELERATORS = {
+                "cpu": CPUAccelerator,
+                "gpu": GPUAccelerator,
+                "tpu": TPUAccelerator,
+                "ipu": IPUAccelerator,
+            }
+            assert self._accelerator_flag is not None
+            self._accelerator_flag = self._accelerator_flag.lower()
+            if self._accelerator_flag not in ACCELERATORS:
+                raise MisconfigurationException(
+                    "When passing string value for the `accelerator` argument of `Trainer`,"
+                    f" it can only be one of {list(ACCELERATORS)}."
                 )
-            else:
-                self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag]  # type: ignore
+            accelerator_class = ACCELERATORS[self._accelerator_flag]
+            self.accelerator = accelerator_class()  # type: ignore[abstract]
 
-        elif self._accelerator_flag == "cpu":
-            self.accelerator = CPUAccelerator()
-            self._set_devices_flag_if_auto_passed()
-            if isinstance(self._devices_flag, int):
-                self._parallel_devices = [torch.device("cpu")] * self._devices_flag
-            else:
-                rank_zero_warn(
-                    "The flag `devices` must be an int with `accelerator='cpu'`,"
-                    f" got `devices={self._devices_flag}` instead."
-                )
+        self._set_devices_flag_if_auto_passed()
 
         self._gpus = self._devices_flag if not self._gpus else self._gpus
         self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores
 
+        self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
+        if not self._parallel_devices:
+            self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
+
     def _set_devices_flag_if_auto_passed(self) -> None:
         if self._devices_flag == "auto" or not self._devices_flag:
             self._devices_flag = self.accelerator.auto_device_count()
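The net effect of this hunk is that the connector no longer special-cases each accelerator: it looks a string flag up in a small registry, or takes a user-supplied Accelerator instance directly (the "pluggable" path), and then defers device handling to the accelerator's own parse_devices/get_parallel_devices hooks. A simplified sketch of that flow, with names shortened and the error handling and deprecated flags omitted; this is not the verbatim connector code:

# Simplified sketch of the selection flow above (illustrative, not the actual connector).
from pytorch_lightning.accelerators import (
    Accelerator,
    CPUAccelerator,
    GPUAccelerator,
    IPUAccelerator,
    TPUAccelerator,
)

ACCELERATORS = {"cpu": CPUAccelerator, "gpu": GPUAccelerator, "tpu": TPUAccelerator, "ipu": IPUAccelerator}


def resolve_accelerator(accelerator_flag, devices_flag):
    if isinstance(accelerator_flag, Accelerator):
        accelerator = accelerator_flag                      # pluggable path: use the instance as-is
    else:
        accelerator = ACCELERATORS[accelerator_flag.lower()]()
    if devices_flag == "auto" or not devices_flag:
        devices_flag = accelerator.auto_device_count()
    devices_flag = accelerator.parse_devices(devices_flag)  # accelerator owns its parsing logic
    return accelerator, accelerator.get_parallel_devices(devices_flag)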

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from torch.utils.data.distributed import DistributedSampler
 
 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import IPUAccelerator
+from pytorch_lightning.accelerators.ipu import IPUAccelerator
 from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn

tests/accelerators/test_accelerator_connector.py

Lines changed: 20 additions & 6 deletions
@@ -341,6 +341,14 @@ def creates_processes_externally(self) -> bool:
 @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
 def test_custom_accelerator(device_count_mock, setup_distributed_mock):
     class Accel(Accelerator):
+        @staticmethod
+        def parse_devices(devices):
+            return devices
+
+        @staticmethod
+        def get_parallel_devices(devices):
+            return [torch.device("cpu")] * devices
+
         @staticmethod
         def auto_device_count() -> int:
             return 1
@@ -413,10 +421,17 @@ def test_ipython_incompatible_backend_error(_, monkeypatch):
         Trainer(strategy="dp")
 
 
-@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")])
-def test_ipython_compatible_backend(trainer_kwargs, monkeypatch):
+@mock.patch("torch.cuda.device_count", return_value=2)
+def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
+    trainer = Trainer(strategy="dp", accelerator="gpu")
+    assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
+
+
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
+def test_ipython_compatible_strategy_tpu(_, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
-    trainer = Trainer(**trainer_kwargs)
+    trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
 
 
@@ -883,10 +898,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
-def test_unsupported_tpu_choice(monkeypatch):
-    import pytorch_lightning.utilities.imports as imports
+@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
+def test_unsupported_tpu_choice(mock_devices):
 
-    monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)