
Commit 825c5db

Add support for (accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto') (#7808)

Authored by: kaushikb11 (with carmocca, ethanwharris, pre-commit-ci[bot], and SeanNaren)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <[email protected]>
1 parent 09ff295 commit 825c5db

File tree: 6 files changed (+228, -37 lines)


CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))
 
 
+- Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808))
+
+
 ### Changed
 
 
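For orientation, here is a minimal usage sketch of the new `accelerator` string values, mirroring the tests added later in this commit. Everything beyond the `Trainer(...)` calls is illustrative, and the TPU/IPU lines only apply on hosts that actually have that hardware.

```python
# Minimal sketch of the accelerator values added by this PR (not part of the commit).
from pytorch_lightning import Trainer

# Explicit device classes:
trainer = Trainer(accelerator="cpu")                 # CPUAccelerator
trainer = Trainer(accelerator="gpu", gpus=1)         # GPUAccelerator; `gpus` must be set
# trainer = Trainer(accelerator="tpu", tpu_cores=8)  # TPUAccelerator, on a TPU host
# trainer = Trainer(accelerator="ipu", ipus=8)       # IPUAccelerator, on an IPU host

# Let Lightning resolve the accelerator from the passed flags (TPU > IPU > GPU > CPU):
trainer = Trainer(accelerator="auto", gpus=1)        # resolves to GPUAccelerator here
```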

pytorch_lightning/trainer/connectors/accelerator_connector.py
Lines changed: 108 additions & 35 deletions

@@ -61,7 +61,9 @@
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
     _HOROVOD_AVAILABLE,
+    _IPU_AVAILABLE,
     _NATIVE_AMP_AVAILABLE,
+    _TPU_AVAILABLE,
     AMPType,
     device_parser,
     DeviceType,
@@ -101,6 +103,7 @@ def __init__(
         # initialization
         self._device_type = DeviceType.CPU
         self._distrib_type = None
+        self._accelerator_type = None
 
         self.num_processes = num_processes
         # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids.
@@ -133,16 +136,19 @@ def __init__(
 
         self.plugins = plugins
 
+        self.select_accelerator_type()
         self.set_distributed_mode()
         self.configure_slurm_ddp()
 
         self.handle_given_plugins()
+        self.update_device_type_if_ipu_plugin()
+        self.validate_accelerator_type()
 
         self._training_type_plugin_resolved = False
         self.accelerator = self.select_accelerator()
 
         # override dist backend when using tpus
-        if self.on_tpu:
+        if self.use_tpu:
             self.distributed_backend = "tpu"
 
         # init flags for SLURM+DDP to work
@@ -164,6 +170,45 @@ def __init__(
 
         self.replace_sampler_ddp = replace_sampler_ddp
 
+    def select_accelerator_type(self) -> None:
+        if self.distributed_backend == "auto":
+            if self.has_tpu:
+                self._accelerator_type = DeviceType.TPU
+            elif self.has_ipu:
+                self._accelerator_type = DeviceType.IPU
+            elif self.has_gpu:
+                self._accelerator_type = DeviceType.GPU
+            else:
+                self._accelerator_type = DeviceType.CPU
+        elif self.distributed_backend == DeviceType.TPU:
+            if not self.has_tpu:
+                msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
+                raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.")
+            self._accelerator_type = DeviceType.TPU
+        elif self.distributed_backend == DeviceType.IPU:
+            if not self.has_ipu:
+                msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
+                raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.")
+            self._accelerator_type = DeviceType.IPU
+        elif self.distributed_backend == DeviceType.GPU:
+            if not self.has_gpu:
+                msg = ("you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available")
+                raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.")
+            self._accelerator_type = DeviceType.GPU
+        elif self.distributed_backend == DeviceType.CPU:
+            self._accelerator_type = DeviceType.CPU
+
+        if self.distributed_backend in ["auto"] + list(DeviceType):
+            self.distributed_backend = None
+
+    def validate_accelerator_type(self) -> None:
+        if self._accelerator_type and self._accelerator_type != self._device_type:
+            raise MisconfigurationException(
+                f"Mismatch between the requested accelerator type ({self._accelerator_type})"
+                f" and assigned device type ({self._device_type})."
+            )
+        self._accelerator_type = self._device_type
+
     def handle_given_plugins(self) -> None:
 
         training_type = None
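The precedence that `select_accelerator_type` applies for `accelerator="auto"` can be illustrated outside the Trainer machinery. The sketch below is not part of the commit: `Device` and `resolve_auto` are hypothetical stand-ins that only mirror the ordering (TPU, then IPU, then GPU, then CPU) shown in the diff above.

```python
# Illustrative only: mirrors the `accelerator="auto"` precedence of
# select_accelerator_type without the Trainer machinery.
from enum import Enum


class Device(Enum):  # hypothetical stand-in for pytorch_lightning's DeviceType
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    IPU = "ipu"


def resolve_auto(has_tpu: bool, has_ipu: bool, has_gpu: bool) -> Device:
    """Pick the accelerator the connector would choose for `accelerator='auto'`.

    The `has_*` flags mean "the user passed the corresponding Trainer argument"
    (`tpu_cores`, `ipus`, `gpus`), not raw hardware availability.
    """
    if has_tpu:
        return Device.TPU
    if has_ipu:
        return Device.IPU
    if has_gpu:
        return Device.GPU
    return Device.CPU


assert resolve_auto(False, False, True) is Device.GPU   # e.g. Trainer(accelerator="auto", gpus=1)
assert resolve_auto(False, False, False) is Device.CPU  # e.g. Trainer(accelerator="auto")
```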
@@ -245,28 +290,49 @@ def cluster_environment(self) -> ClusterEnvironment:
         return self._cluster_environment
 
     @property
-    def on_cpu(self) -> bool:
-        return self._device_type == DeviceType.CPU
+    def has_cpu(self) -> bool:
+        return True
+
+    @property
+    def use_cpu(self) -> bool:
+        return self._accelerator_type == DeviceType.CPU
+
+    @property
+    def has_gpu(self) -> bool:
+        # Here, we are not checking for GPU availability, but instead if User has passed
+        # `gpus` to Trainer for training.
+        gpus = self.parallel_device_ids
+        return gpus is not None and len(gpus) > 0
+
+    @property
+    def use_gpu(self) -> bool:
+        return self._accelerator_type == DeviceType.GPU and self.has_gpu
 
     @property
-    def on_tpu(self) -> bool:
+    def has_tpu(self) -> bool:
+        # Here, we are not checking for TPU availability, but instead if User has passed
+        # `tpu_cores` to Trainer for training.
         return self.tpu_cores is not None
 
     @property
-    def on_ipu(self) -> bool:
-        return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin)
+    def use_tpu(self) -> bool:
+        return self._accelerator_type == DeviceType.TPU and self.has_tpu
 
     @property
     def tpu_id(self) -> Optional[int]:
-        if self.on_tpu and isinstance(self.tpu_cores, list):
+        if self.use_tpu and isinstance(self.tpu_cores, list):
             return self.tpu_cores[0]
-
         return None
 
     @property
-    def on_gpu(self) -> bool:
-        gpus = self.parallel_device_ids
-        return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()
+    def has_ipu(self) -> bool:
+        # Here, we are not checking for IPU availability, but instead if User has passed
+        # `ipus` to Trainer for training.
+        return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin)
+
+    @property
+    def use_ipu(self) -> bool:
+        return self._accelerator_type == DeviceType.IPU and self.has_ipu
 
     @property
     def use_dp(self) -> bool:
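The renamed properties separate two questions: `has_*` ("did the user pass the corresponding Trainer flag?") and `use_*` ("is that also the resolved accelerator type?"). A hedged sketch of that split, using hypothetical names that are not part of the commit:

```python
# Illustrative sketch of the has_* / use_* split; only the logic mirrors the diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FlagState:
    gpus: Optional[int] = None      # what the user passed to Trainer(gpus=...)
    accelerator_type: str = "cpu"   # what select_accelerator_type resolved

    @property
    def has_gpu(self) -> bool:
        # "did the user ask for GPUs?" -- no torch.cuda.is_available() check here
        return self.gpus is not None and self.gpus > 0

    @property
    def use_gpu(self) -> bool:
        # "are we actually going to train on GPUs?"
        return self.accelerator_type == "gpu" and self.has_gpu


state = FlagState(gpus=2, accelerator_type="cpu")  # e.g. Trainer(accelerator="cpu", gpus=2)
assert state.has_gpu and not state.use_gpu
```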
@@ -308,10 +374,10 @@ def _is_fully_sharded_training_type(self) -> bool:
     def is_distributed(self) -> bool:
         # Used for custom plugins.
         # Custom plugins should implement is_distributed property.
-        if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu:
+        if hasattr(self.training_type_plugin, 'is_distributed') and not self.use_tpu:
             return self.training_type_plugin.is_distributed
         is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod
-        if self.on_tpu:
+        if self.use_tpu:
             is_distributed |= self.training_type_plugin.is_distributed
         return is_distributed
@@ -332,14 +398,14 @@ def num_ipus(self) -> int:
 
     @property
     def parallel_devices(self) -> List[Union[torch.device, int]]:
-        if self.on_gpu:
+        if self.use_gpu:
             devices = [torch.device("cuda", i) for i in self.parallel_device_ids]
-        elif self.on_tpu:
+        elif self.use_tpu:
             # explicitly don't make a tpu device here!
             # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
             if isinstance(self.tpu_cores, int):
                 devices = list(range(self.tpu_cores))
-        elif self.on_ipu:
+        elif self.use_ipu:
             devices = list(range(self.num_ipus))
         else:
             devices = [torch.device("cpu")] * self.num_processes
@@ -373,7 +439,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
         self.amp_type = AMPType.from_str(self.amp_type)
 
-        if self.on_ipu:
+        if self.use_ipu:
             return IPUPrecisionPlugin(self.precision)
 
         if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
@@ -384,11 +450,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
         if self.precision == 64:
             return DoublePrecisionPlugin()
         if self.precision == 16:
-            if self.on_tpu:
+            if self.use_tpu:
                 return TPUHalfPrecisionPlugin()
 
             if self.amp_type == AMPType.NATIVE:
-                if self.on_cpu:
+                if self.use_cpu:
                     raise MisconfigurationException(
                         "You have asked for native AMP on CPU, but AMP is only available on GPU."
                     )
@@ -444,8 +510,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
         use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
         use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
         use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
-        use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
-        use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
+        use_ddp_cpu_spawn = self.use_ddp and self.use_cpu
+        use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
         use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
         use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
         use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
@@ -484,13 +550,13 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
         elif self.use_horovod:
             plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
-        elif self.on_tpu and isinstance(self.tpu_cores, list):
+        elif self.use_tpu and isinstance(self.tpu_cores, list):
             plugin = SingleTPUPlugin(self.tpu_id)
-        elif self.on_ipu:
+        elif self.use_ipu:
             plugin = IPUPlugin(parallel_devices=self.parallel_devices)
         else:
             single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
-            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu"))
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
@@ -526,11 +592,11 @@ def select_accelerator(self) -> Accelerator:
             )
             return self.distributed_backend
 
-        if self.on_gpu:
+        if self.use_gpu:
             acc_cls = GPUAccelerator
-        elif self.on_tpu:
+        elif self.use_tpu:
             acc_cls = TPUAccelerator
-        elif self.on_ipu:
+        elif self.use_ipu:
             acc_cls = IPUAccelerator
         else:
             acc_cls = CPUAccelerator
@@ -574,12 +640,15 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         if isinstance(self.distributed_backend, Accelerator):
             return
 
+        is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
+        _use_cpu = is_cpu_accelerator_type or self.distributed_backend and 'cpu' in self.distributed_backend
+
         if self.distributed_backend is None:
             if self.has_horovodrun():
                 self._set_horovod_backend()
             elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1):
                 self._distrib_type = DistributedType.DDP
-            elif self.num_gpus > 1:
+            elif self.num_gpus > 1 and not _use_cpu:
                 rank_zero_warn(
                     'You requested multiple GPUs but did not specify a backend, e.g.'
                     ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.'
@@ -598,23 +667,21 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
             # define the max CPU available
             self.num_processes = os.cpu_count()
         # special case with TPUs
-        elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
+        elif self.has_tpu and not _use_cpu:
             self._device_type = DeviceType.TPU
             if isinstance(self.tpu_cores, int):
                 self._distrib_type = DistributedType.TPU_SPAWN
-        elif self.distributed_backend == 'ipu':
+        elif self.has_ipu and not _use_cpu:
             self._device_type = DeviceType.IPU
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = DistributedType(self.distributed_backend)
 
-        # unless you request explicitly for CPU and some GPU are available use them
-        _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
-        if self.num_gpus > 0 and not _on_cpu:
+        if self.num_gpus > 0 and not _use_cpu:
             self._device_type = DeviceType.GPU
 
         _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
         # DP and DDP2 cannot run without GPU
-        if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
+        if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu:
             rank_zero_warn(
                 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
             )
@@ -656,7 +723,7 @@ def _set_horovod_backend(self):
 
         # Initialize Horovod to get rank / size info
         hvd.init()
-        if self.on_gpu:
+        if self.has_gpu:
             # Horovod assigns one local GPU per process
             self.parallel_device_ids = list(range(hvd.local_size()))
         else:
@@ -694,6 +761,12 @@ def has_horovodrun() -> bool:
         """Returns True if running with `horovodrun` using Gloo or OpenMPI."""
         return "OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ
 
+    def update_device_type_if_ipu_plugin(self) -> None:
+        # This allows the poptorch.Options that are passed into the IPUPlugin to be the source of truth,
+        # which gives users the flexibility to not have to pass `ipus` flag directly to Trainer
+        if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU:
+            self._device_type = DeviceType.IPU
+
     def configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
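In `set_distributed_mode`, the old `_on_cpu` guard only looked at the legacy `'ddp_cpu'` backend string; the new `_use_cpu` also honours an explicit `accelerator='cpu'`, which is why `Trainer(accelerator="cpu", gpus=2)` stays on CPU in the tests below. A hedged, standalone reproduction of that guard (the function and parameter names are stand-ins, not part of the commit):

```python
# Illustrative: how the new `_use_cpu` guard is computed in set_distributed_mode.
from typing import Optional


def use_cpu(accelerator_type: Optional[str], distributed_backend: Optional[str]) -> bool:
    # CPU wins if the resolved accelerator type is CPU, or if a legacy
    # 'cpu'-flavoured backend string (e.g. 'ddp_cpu') was requested.
    is_cpu_accelerator_type = accelerator_type is not None and accelerator_type == "cpu"
    return bool(is_cpu_accelerator_type or (distributed_backend and "cpu" in distributed_backend))


assert use_cpu("cpu", None)       # Trainer(accelerator="cpu", gpus=2) -> stays on CPU
assert use_cpu(None, "ddp_cpu")   # legacy string backend still honoured
assert not use_cpu("gpu", "ddp")  # GPU request is untouched
```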

pytorch_lightning/trainer/trainer.py
Lines changed: 1 addition & 1 deletion

@@ -1226,7 +1226,7 @@ def __setup_profiler(self) -> None:
     def _log_device_info(self) -> None:
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
 
-        num_tpu_cores = self.tpu_cores if self.tpu_cores is not None else 0
+        num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0
         rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores')
 
         num_ipus = self.ipus if self.ipus is not None else 0
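The one-line change above means the TPU-core count is only reported when the resolved device type is actually TPU. A small standalone sketch of that guard (the helper and its arguments are hypothetical stand-ins for the Trainer state):

```python
# Illustrative reproduction of the guarded TPU-core count in `_log_device_info`.
from typing import List, Optional, Union


def reported_tpu_cores(tpu_cores: Optional[Union[int, List[int]]], device_type: str) -> Union[int, List[int]]:
    # TPU cores are only reported when the resolved device type is TPU.
    return tpu_cores if tpu_cores is not None and device_type == "tpu" else 0


assert reported_tpu_cores(8, "tpu") == 8
assert reported_tpu_cores(8, "cpu") == 0   # e.g. accelerator="cpu" suppresses the count
assert reported_tpu_cores(None, "tpu") == 0
```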

tests/accelerators/test_accelerator_connector.py
Lines changed: 52 additions & 0 deletions

@@ -587,3 +587,55 @@ def test_accelerator_choice_multi_node_gpu(
         gpus=gpus,
     )
     assert isinstance(trainer.training_type_plugin, plugin)
+
+
+@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't require GPU")
+def test_accelerator_cpu():
+
+    trainer = Trainer(accelerator="cpu")
+
+    assert trainer._device_type == "cpu"
+    assert isinstance(trainer.accelerator, CPUAccelerator)
+
+    with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but GPUs are not available"):
+        trainer = Trainer(accelerator="gpu")
+
+    with pytest.raises(MisconfigurationException, match="You requested GPUs:"):
+        trainer = Trainer(accelerator="cpu", gpus=1)
+
+
+@RunIf(min_gpus=1)
+def test_accelerator_gpu():
+
+    trainer = Trainer(accelerator="gpu", gpus=1)
+
+    assert trainer._device_type == "gpu"
+    assert isinstance(trainer.accelerator, GPUAccelerator)
+
+    with pytest.raises(
+        MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`"
+    ):
+        trainer = Trainer(accelerator="gpu")
+
+    trainer = Trainer(accelerator="auto", gpus=1)
+
+    assert trainer._device_type == "gpu"
+    assert isinstance(trainer.accelerator, GPUAccelerator)
+
+
+@RunIf(min_gpus=1)
+def test_accelerator_cpu_with_gpus_flag():
+
+    trainer = Trainer(accelerator="cpu", gpus=1)
+
+    assert trainer._device_type == "cpu"
+    assert isinstance(trainer.accelerator, CPUAccelerator)
+
+
+@RunIf(min_gpus=2)
+def test_accelerator_cpu_with_multiple_gpus():
+
+    trainer = Trainer(accelerator="cpu", gpus=2)
+
+    assert trainer._device_type == "cpu"
+    assert isinstance(trainer.accelerator, CPUAccelerator)
