
Add support for (accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto') #7808


Merged on Jul 9, 2021 (53 commits)
Changes from 28 of 53 commits

Commits
d1b21ec
Add support for (accelerator='cpu'|'gpu'|'tpu'|'auto')
kaushikb11 Jun 2, 2021
64a7c32
GPU Available
kaushikb11 Jun 2, 2021
ae6876a
Add new properties
kaushikb11 Jun 2, 2021
4487353
Fix
kaushikb11 Jun 2, 2021
08c3520
Add validate accelerate type
kaushikb11 Jun 2, 2021
06c5e9e
Tested on TPUs
kaushikb11 Jun 2, 2021
f9024cd
Update for CPU device
kaushikb11 Jun 10, 2021
cb549d0
Update for auto
kaushikb11 Jun 10, 2021
86abd61
Add tests
kaushikb11 Jun 10, 2021
b8b86a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
2ea22df
Add tests
kaushikb11 Jun 10, 2021
32bc1b9
Add tests
kaushikb11 Jun 10, 2021
6c7cf45
Add tests
kaushikb11 Jun 10, 2021
a1c88da
Add exception
kaushikb11 Jun 10, 2021
1752ae5
Update changelog
kaushikb11 Jun 10, 2021
9eb686d
Merge branch 'master' into accelerator/auto
kaushikb11 Jun 10, 2021
470cacf
Update
kaushikb11 Jun 10, 2021
b5d4d3f
Merge branch 'accelerator/auto' of https://github.com/kaushikb11/pyto…
kaushikb11 Jun 10, 2021
d15e394
Update
kaushikb11 Jun 10, 2021
0588d24
Merge branch 'master' into accelerator/auto
kaushikb11 Jun 30, 2021
f9b2103
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2021
4fa334b
Update on_x to has_x
kaushikb11 Jun 30, 2021
3623ed6
Add updates for IPU
kaushikb11 Jun 30, 2021
3e3b575
Add tests for IPUs
kaushikb11 Jun 30, 2021
913093f
Update changelog
kaushikb11 Jun 30, 2021
bbf7f50
Fix typo
kaushikb11 Jun 30, 2021
ad65213
Fix use_cpu
kaushikb11 Jun 30, 2021
a402846
Updates for ipus
kaushikb11 Jun 30, 2021
1a895de
Merge branch 'master' into accelerator/auto
kaushikb11 Jul 6, 2021
d8458a4
Address comments
kaushikb11 Jul 6, 2021
125fcca
Multi gpus + cpu acc
kaushikb11 Jul 7, 2021
b27e548
Fix ipu test
kaushikb11 Jul 7, 2021
e6855a4
Update tpu tests
kaushikb11 Jul 7, 2021
fc21406
Merge branch 'master' into accelerator/auto
kaushikb11 Jul 7, 2021
90658a6
Update ddp_spawn
kaushikb11 Jul 7, 2021
0c5b01e
Merge branch 'accelerator/auto' of https://github.com/kaushikb11/pyto…
kaushikb11 Jul 7, 2021
3163439
Update tests
kaushikb11 Jul 7, 2021
7f7dbaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
c71e4e1
Update
kaushikb11 Jul 7, 2021
a9672cb
Merge branch 'accelerator/auto' of https://github.com/kaushikb11/pyto…
kaushikb11 Jul 7, 2021
7a5ff5b
Use has_tpu
kaushikb11 Jul 7, 2021
deb6770
Merge branch 'master' into accelerator/auto
kaushikb11 Jul 8, 2021
766ef1c
Add update_device_type_if_ipu_plugin
kaushikb11 Jul 9, 2021
6ce6ee8
Merge branch 'accelerator/auto' of https://github.com/kaushikb11/pyto…
kaushikb11 Jul 9, 2021
53baf92
Fix IPU tests
Jul 9, 2021
134a90e
Cleanup
Jul 9, 2021
8b28893
Merge branch 'master' into accelerator/auto
Jul 9, 2021
5aa1a03
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Jul 9, 2021
d45e89c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2021
5337414
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Jul 9, 2021
8d3d2f0
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Jul 9, 2021
c39930c
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Jul 9, 2021
a2847ec
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Jul 9, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -71,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))


- Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808))


- Added `save_config_overwrite` init argument to `LightningCLI` to ease overwriting existing config files ([#8059](https://github.com/PyTorchLightning/pytorch-lightning/pull/8059))


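For context, a minimal usage sketch of the flag this PR adds, based on the tests further down (the `gpus=1` lines assume a machine with at least one CUDA GPU):

```python
from pytorch_lightning import Trainer

# Explicit accelerator: the connector validates that the matching device flag
# (`gpus`, `tpu_cores`, `ipus`) was passed and that the hardware is available.
trainer = Trainer(accelerator="gpu", gpus=1)

# "auto": the device type is resolved from the flags that were passed,
# in the order TPU, IPU, GPU, CPU (see `select_accelerator_type` below).
trainer = Trainer(accelerator="auto", gpus=1)

# CPU needs no extra device flag.
trainer = Trainer(accelerator="cpu")
```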
137 changes: 103 additions & 34 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -61,7 +61,9 @@
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
device_parser,
DeviceType,
@@ -101,6 +103,7 @@ def __init__(
# initialization
self._device_type = DeviceType.CPU
self._distrib_type = None
self._accelerator_type = None

self.num_processes = num_processes
self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
@@ -138,16 +141,18 @@ def __init__(

self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)

self.select_accelerator_type()
self.set_distributed_mode()
self.configure_slurm_ddp()

self.handle_given_plugins()
self.validate_accelerator_type()

self._training_type_plugin_resolved = False
self.accelerator = self.select_accelerator()

# override dist backend when using tpus
if self.on_tpu:
if self.use_tpu:
self.distributed_backend = "tpu"

# init flags for SLURM+DDP to work
@@ -169,6 +174,48 @@

self.replace_sampler_ddp = replace_sampler_ddp

def select_accelerator_type(self) -> None:
if self.distributed_backend == "auto":
if self.has_tpu:
self._accelerator_type = DeviceType.TPU
elif self.has_ipu:
self._accelerator_type = DeviceType.IPU
elif self.has_gpu:
self._accelerator_type = DeviceType.GPU
else:
self._accelerator_type = DeviceType.CPU
elif self.distributed_backend == DeviceType.TPU:
if not self.has_tpu:
msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`"
raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}")
self._accelerator_type = DeviceType.TPU
elif self.distributed_backend == DeviceType.IPU:
if not self.has_ipu:
msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`"
raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}")
self._accelerator_type = DeviceType.IPU
elif self.distributed_backend == DeviceType.GPU:
if not self.has_gpu:
msg = "GPUs are not available" if not torch.cuda.is_available(
) else "you didn't pass `gpus` to `Trainer`"
raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}")
self._accelerator_type = DeviceType.GPU
elif self.distributed_backend == DeviceType.CPU:
self._accelerator_type = DeviceType.CPU

accelerator_types = DeviceType.__members__.values()
if self.distributed_backend in ["auto", *accelerator_types]:
self.distributed_backend = None

def validate_accelerator_type(self) -> None:
if self._accelerator_type and self._accelerator_type != self._device_type:
raise MisconfigurationException(
f"Mismatch between the requested {self._accelerator_type}"
f" and assigned {self._device_type}"
)
else:
self._accelerator_type = self._device_type

def handle_given_plugins(self) -> None:

training_type = None
@@ -250,28 +297,49 @@ def cluster_environment(self) -> ClusterEnvironment:
return self._cluster_environment

@property
def on_cpu(self) -> bool:
return self._device_type == DeviceType.CPU
def has_cpu(self) -> bool:
return True

@property
def use_cpu(self) -> bool:
return self._accelerator_type == DeviceType.CPU

@property
def has_gpu(self) -> bool:
# Here, we are not checking for GPU availability, but whether the user has passed
# `gpus` to the Trainer for training.
gpus = self.parallel_device_ids
return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()

@property
def on_tpu(self) -> bool:
def use_gpu(self) -> bool:
return self._accelerator_type == DeviceType.GPU and self.has_gpu

@property
def has_tpu(self) -> bool:
# Here, we are not checking for TPU availability, but whether the user has passed
# `tpu_cores` to the Trainer for training.
return self.tpu_cores is not None

@property
def on_ipu(self) -> bool:
return self.ipus is not None
def use_tpu(self) -> bool:
return self._accelerator_type == DeviceType.TPU and self.has_tpu

@property
def tpu_id(self) -> Optional[int]:
if self.on_tpu and isinstance(self.tpu_cores, list):
if self.use_tpu and isinstance(self.tpu_cores, list):
return self.tpu_cores[0]

return None

@property
def on_gpu(self) -> bool:
gpus = self.parallel_device_ids
return gpus is not None and len(gpus) > 0 and torch.cuda.is_available()
def has_ipu(self) -> bool:
# Here, we are not checking for IPU availability, but whether the user has passed
# `ipus` to the Trainer for training.
return self.ipus is not None

@property
def use_ipu(self) -> bool:
return self._accelerator_type == DeviceType.IPU and self.has_ipu

@property
def use_dp(self) -> bool:
@@ -313,10 +381,10 @@ def _is_fully_sharded_training_type(self) -> bool:
def is_distributed(self) -> bool:
# Used for custom plugins.
# Custom plugins should implement is_distributed property.
if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu:
if hasattr(self.training_type_plugin, 'is_distributed') and not self.use_tpu:
return self.training_type_plugin.is_distributed
is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod
if self.on_tpu:
if self.use_tpu:
is_distributed |= self.training_type_plugin.is_distributed
return is_distributed

@@ -329,14 +397,14 @@ def num_gpus(self) -> int:

@property
def parallel_devices(self) -> List[Union[torch.device, int]]:
if self.on_gpu:
if self.use_gpu:
devices = [torch.device("cuda", i) for i in self.parallel_device_ids]
elif self.on_tpu:
elif self.use_tpu:
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
if isinstance(self.tpu_cores, int):
devices = list(range(self.tpu_cores))
elif self.on_ipu:
elif self.use_ipu:
if isinstance(self.ipus, int):
devices = list(range(self.ipus))
else:
@@ -371,7 +439,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
# set precision type
self.amp_type = AMPType.from_str(self.amp_type)

if self.on_ipu:
if self.use_ipu:
return IPUPrecisionPlugin(self.precision)

if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
@@ -382,11 +450,11 @@ def select_precision_plugin(self) -> PrecisionPlugin:
if self.precision == 64:
return DoublePrecisionPlugin()
if self.precision == 16:
if self.on_tpu:
if self.use_tpu:
return TPUHalfPrecisionPlugin()

if self.amp_type == AMPType.NATIVE:
if self.on_cpu:
if self.use_cpu:
raise MisconfigurationException(
"You have asked for native AMP on CPU, but AMP is only available on GPU."
)
@@ -442,8 +510,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.use_cpu
use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
@@ -482,13 +550,13 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu and isinstance(self.tpu_cores, list):
elif self.use_tpu and isinstance(self.tpu_cores, list):
plugin = SingleTPUPlugin(self.tpu_id)
elif self.on_ipu:
elif self.use_ipu:
plugin = IPUPlugin(parallel_devices=self.parallel_devices)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu"))
return plugin

def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
@@ -524,11 +592,11 @@ def select_accelerator(self) -> Accelerator:
)
return self.distributed_backend

if self.on_gpu:
if self.use_gpu:
acc_cls = GPUAccelerator
elif self.on_tpu:
elif self.use_tpu:
acc_cls = TPUAccelerator
elif self.on_ipu:
elif self.use_ipu:
acc_cls = IPUAccelerator
else:
acc_cls = CPUAccelerator
@@ -582,6 +650,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
)
self.distributed_backend = "ddp_spawn"

is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU
_use_cpu = is_cpu_accelerator_type or self.distributed_backend and 'cpu' in self.distributed_backend

# special case with DDP on CPUs
if self.distributed_backend == "ddp_cpu":
self._distrib_type = DistributedType.DDP_SPAWN
@@ -594,23 +665,21 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# define the max CPU available
self.num_processes = os.cpu_count()
# special case with TPUs
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
elif self.tpu_cores is not None and not _use_cpu:
self._device_type = DeviceType.TPU
if isinstance(self.tpu_cores, int):
self._distrib_type = DistributedType.TPU_SPAWN
elif self.distributed_backend == 'ipu':
elif self.has_ipu and not _use_cpu:
self._device_type = DeviceType.IPU
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)

# unless you request explicitly for CPU and some GPU are available use them
_on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend
if self.num_gpus > 0 and not _on_cpu:
if self.num_gpus > 0 and not _use_cpu:
self._device_type = DeviceType.GPU

_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu:
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
@@ -652,7 +721,7 @@ def _set_horovod_backend(self):

# Initialize Horovod to get rank / size info
hvd.init()
if self.on_gpu:
if self.use_gpu:
# Horovod assigns one local GPU per process
self.parallel_device_ids = list(range(hvd.local_size()))
else:
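To summarize the renamed properties above (the old `on_x` checks become `has_x`/`use_x`): `has_x` only reflects whether the user passed the corresponding device flag (`gpus` additionally requires CUDA to be available), while `use_x` also requires that the resolved accelerator type matches. A standalone sketch of that resolution order, independent of the Lightning classes (all names below are illustrative, not library API):

```python
from enum import Enum
from typing import Optional


class Device(str, Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    IPU = "ipu"


def resolve_accelerator(
    accelerator: str,
    gpus: Optional[int] = None,
    tpu_cores: Optional[int] = None,
    ipus: Optional[int] = None,
    cuda_available: bool = False,
) -> Device:
    """Mirror the precedence used by select_accelerator_type above."""
    has_tpu = tpu_cores is not None
    has_ipu = ipus is not None
    has_gpu = gpus is not None and gpus > 0 and cuda_available

    if accelerator == "auto":
        # 'auto' picks the most specialized hardware the user asked for.
        if has_tpu:
            return Device.TPU
        if has_ipu:
            return Device.IPU
        if has_gpu:
            return Device.GPU
        return Device.CPU

    # Explicit request: fail if the matching flag or hardware is missing.
    requested = Device(accelerator)
    available = {Device.CPU: True, Device.GPU: has_gpu, Device.TPU: has_tpu, Device.IPU: has_ipu}
    if not available[requested]:
        raise ValueError(f"You passed accelerator='{accelerator}', but it is not usable here")
    return requested


assert resolve_accelerator("auto", tpu_cores=8, gpus=2, cuda_available=True) is Device.TPU
assert resolve_accelerator("auto", gpus=2, cuda_available=True) is Device.GPU
assert resolve_accelerator("cpu", gpus=1, cuda_available=True) is Device.CPU
```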
43 changes: 43 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
@@ -569,3 +569,46 @@ def test_accelerator_choice_multi_node_gpu(
gpus=gpus,
)
assert isinstance(trainer.training_type_plugin, plugin)


@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't require GPU")
def test_accelerator_cpu():

trainer = Trainer(accelerator="cpu")

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)

with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but GPUs are not available"):
trainer = Trainer(accelerator="gpu")

with pytest.raises(MisconfigurationException, match="You requested GPUs:"):
trainer = Trainer(accelerator="cpu", gpus=1)


@RunIf(min_gpus=1)
def test_accelerator_gpu():

trainer = Trainer(accelerator="gpu", gpus=1)

assert trainer._device_type == "gpu"
assert isinstance(trainer.accelerator, GPUAccelerator)

with pytest.raises(
MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`"
):
trainer = Trainer(accelerator="gpu")

trainer = Trainer(accelerator="auto", gpus=1)

assert trainer._device_type == "gpu"
assert isinstance(trainer.accelerator, GPUAccelerator)


@RunIf(min_gpus=1)
def test_accelerator_cpu_with_gpus_flag():

trainer = Trainer(accelerator="cpu", gpus=1)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)
30 changes: 29 additions & 1 deletion tests/accelerators/test_ipu.py
@@ -19,7 +19,7 @@
import torch.nn.functional as F

from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.accelerators import IPUAccelerator
from pytorch_lightning.accelerators import CPUAccelerator, IPUAccelerator
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
from pytorch_lightning.trainer.states import RunningStage
@@ -506,3 +506,31 @@ def test_precision_plugin(tmpdir):

plugin = IPUPrecisionPlugin(precision=16)
assert plugin.precision == 16


@RunIf(ipu=True)
def test_accelerator_ipu():

trainer = Trainer(accelerator="ipu", ipus=1)

assert trainer._device_type == "ipu"
assert isinstance(trainer.accelerator, IPUAccelerator)

with pytest.raises(
MisconfigurationException, match="You passed `accelerator='ipu'`, but you didn't pass `ipus` to `Trainer`"
):
trainer = Trainer(accelerator="ipu")

trainer = Trainer(accelerator="auto", ipus=8)

assert trainer._device_type == "ipu"
assert isinstance(trainer.accelerator, IPUAccelerator)


@RunIf(ipu=True)
def test_accelerator_cpu_with_ipus_flag():

trainer = Trainer(accelerator="cpu", ipus=1)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)