diff --git a/CHANGELOG.md b/CHANGELOG.md index a5a9f88de72da..8e0829fa44068 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102)) +- Added support for `accelerator='cpu'|'gpu'|'tpu'|'ipu'|'auto'` ([#7808](https://github.com/PyTorchLightning/pytorch-lightning/pull/7808)) + + ### Changed diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 2a643c64f5e64..a9355741a2e6d 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -61,7 +61,9 @@ from pytorch_lightning.utilities import ( _APEX_AVAILABLE, _HOROVOD_AVAILABLE, + _IPU_AVAILABLE, _NATIVE_AMP_AVAILABLE, + _TPU_AVAILABLE, AMPType, device_parser, DeviceType, @@ -101,6 +103,7 @@ def __init__( # initialization self._device_type = DeviceType.CPU self._distrib_type = None + self._accelerator_type = None self.num_processes = num_processes # `gpus` is the input passed to the Trainer, whereas `gpu_ids` is a list of parsed gpu ids. @@ -133,16 +136,19 @@ def __init__( self.plugins = plugins + self.select_accelerator_type() self.set_distributed_mode() self.configure_slurm_ddp() self.handle_given_plugins() + self.update_device_type_if_ipu_plugin() + self.validate_accelerator_type() self._training_type_plugin_resolved = False self.accelerator = self.select_accelerator() # override dist backend when using tpus - if self.on_tpu: + if self.use_tpu: self.distributed_backend = "tpu" # init flags for SLURM+DDP to work @@ -164,6 +170,45 @@ def __init__( self.replace_sampler_ddp = replace_sampler_ddp + def select_accelerator_type(self) -> None: + if self.distributed_backend == "auto": + if self.has_tpu: + self._accelerator_type = DeviceType.TPU + elif self.has_ipu: + self._accelerator_type = DeviceType.IPU + elif self.has_gpu: + self._accelerator_type = DeviceType.GPU + else: + self._accelerator_type = DeviceType.CPU + elif self.distributed_backend == DeviceType.TPU: + if not self.has_tpu: + msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`" + raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.") + self._accelerator_type = DeviceType.TPU + elif self.distributed_backend == DeviceType.IPU: + if not self.has_ipu: + msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`" + raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.") + self._accelerator_type = DeviceType.IPU + elif self.distributed_backend == DeviceType.GPU: + if not self.has_gpu: + msg = ("you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available") + raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.") + self._accelerator_type = DeviceType.GPU + elif self.distributed_backend == DeviceType.CPU: + self._accelerator_type = DeviceType.CPU + + if self.distributed_backend in ["auto"] + list(DeviceType): + self.distributed_backend = None + + def validate_accelerator_type(self) -> None: + if self._accelerator_type and self._accelerator_type != self._device_type: + raise MisconfigurationException( + f"Mismatch between the requested accelerator type ({self._accelerator_type})" + f" and assigned device type ({self._device_type})." + ) + self._accelerator_type = self._device_type + def handle_given_plugins(self) -> None: training_type = None @@ -245,28 +290,49 @@ def cluster_environment(self) -> ClusterEnvironment: return self._cluster_environment @property - def on_cpu(self) -> bool: - return self._device_type == DeviceType.CPU + def has_cpu(self) -> bool: + return True + + @property + def use_cpu(self) -> bool: + return self._accelerator_type == DeviceType.CPU + + @property + def has_gpu(self) -> bool: + # Here, we are not checking for GPU availability, but instead if User has passed + # `gpus` to Trainer for training. + gpus = self.parallel_device_ids + return gpus is not None and len(gpus) > 0 + + @property + def use_gpu(self) -> bool: + return self._accelerator_type == DeviceType.GPU and self.has_gpu @property - def on_tpu(self) -> bool: + def has_tpu(self) -> bool: + # Here, we are not checking for TPU availability, but instead if User has passed + # `tpu_cores` to Trainer for training. return self.tpu_cores is not None @property - def on_ipu(self) -> bool: - return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin) + def use_tpu(self) -> bool: + return self._accelerator_type == DeviceType.TPU and self.has_tpu @property def tpu_id(self) -> Optional[int]: - if self.on_tpu and isinstance(self.tpu_cores, list): + if self.use_tpu and isinstance(self.tpu_cores, list): return self.tpu_cores[0] - return None @property - def on_gpu(self) -> bool: - gpus = self.parallel_device_ids - return gpus is not None and len(gpus) > 0 and torch.cuda.is_available() + def has_ipu(self) -> bool: + # Here, we are not checking for IPU availability, but instead if User has passed + # `ipus` to Trainer for training. + return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin) + + @property + def use_ipu(self) -> bool: + return self._accelerator_type == DeviceType.IPU and self.has_ipu @property def use_dp(self) -> bool: @@ -308,10 +374,10 @@ def _is_fully_sharded_training_type(self) -> bool: def is_distributed(self) -> bool: # Used for custom plugins. # Custom plugins should implement is_distributed property. - if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: + if hasattr(self.training_type_plugin, 'is_distributed') and not self.use_tpu: return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod - if self.on_tpu: + if self.use_tpu: is_distributed |= self.training_type_plugin.is_distributed return is_distributed @@ -332,14 +398,14 @@ def num_ipus(self) -> int: @property def parallel_devices(self) -> List[Union[torch.device, int]]: - if self.on_gpu: + if self.use_gpu: devices = [torch.device("cuda", i) for i in self.parallel_device_ids] - elif self.on_tpu: + elif self.use_tpu: # explicitly don't make a tpu device here! # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169 if isinstance(self.tpu_cores, int): devices = list(range(self.tpu_cores)) - elif self.on_ipu: + elif self.use_ipu: devices = list(range(self.num_ipus)) else: devices = [torch.device("cpu")] * self.num_processes @@ -373,7 +439,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: # set precision type self.amp_type = AMPType.from_str(self.amp_type) - if self.on_ipu: + if self.use_ipu: return IPUPrecisionPlugin(self.precision) if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): @@ -384,11 +450,11 @@ def select_precision_plugin(self) -> PrecisionPlugin: if self.precision == 64: return DoublePrecisionPlugin() if self.precision == 16: - if self.on_tpu: + if self.use_tpu: return TPUHalfPrecisionPlugin() if self.amp_type == AMPType.NATIVE: - if self.on_cpu: + if self.use_cpu: raise MisconfigurationException( "You have asked for native AMP on CPU, but AMP is only available on GPU." ) @@ -444,8 +510,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic() use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow() use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN - use_ddp_cpu_spawn = self.use_ddp and self.on_cpu - use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN + use_ddp_cpu_spawn = self.use_ddp and self.use_cpu + use_tpu_spawn = self.use_tpu and self._distrib_type == DistributedType.TPU_SPAWN use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic() use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks @@ -484,13 +550,13 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) elif self.use_horovod: plugin = HorovodPlugin(parallel_devices=self.parallel_devices) - elif self.on_tpu and isinstance(self.tpu_cores, list): + elif self.use_tpu and isinstance(self.tpu_cores, list): plugin = SingleTPUPlugin(self.tpu_id) - elif self.on_ipu: + elif self.use_ipu: plugin = IPUPlugin(parallel_devices=self.parallel_devices) else: single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) - plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu")) + plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu")) return plugin def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin: @@ -526,11 +592,11 @@ def select_accelerator(self) -> Accelerator: ) return self.distributed_backend - if self.on_gpu: + if self.use_gpu: acc_cls = GPUAccelerator - elif self.on_tpu: + elif self.use_tpu: acc_cls = TPUAccelerator - elif self.on_ipu: + elif self.use_ipu: acc_cls = IPUAccelerator else: acc_cls = CPUAccelerator @@ -574,12 +640,15 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): if isinstance(self.distributed_backend, Accelerator): return + is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU + _use_cpu = is_cpu_accelerator_type or self.distributed_backend and 'cpu' in self.distributed_backend + if self.distributed_backend is None: if self.has_horovodrun(): self._set_horovod_backend() elif self.num_gpus == 0 and (self.num_nodes > 1 or self.num_processes > 1): self._distrib_type = DistributedType.DDP - elif self.num_gpus > 1: + elif self.num_gpus > 1 and not _use_cpu: rank_zero_warn( 'You requested multiple GPUs but did not specify a backend, e.g.' ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' @@ -598,23 +667,21 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): # define the max CPU available self.num_processes = os.cpu_count() # special case with TPUs - elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: + elif self.has_tpu and not _use_cpu: self._device_type = DeviceType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = DistributedType.TPU_SPAWN - elif self.distributed_backend == 'ipu': + elif self.has_ipu and not _use_cpu: self._device_type = DeviceType.IPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) - # unless you request explicitly for CPU and some GPU are available use them - _on_cpu = self.distributed_backend and 'cpu' in self.distributed_backend - if self.num_gpus > 0 and not _on_cpu: + if self.num_gpus > 0 and not _use_cpu: self._device_type = DeviceType.GPU _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) # DP and DDP2 cannot run without GPU - if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu: + if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu: rank_zero_warn( 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) @@ -656,7 +723,7 @@ def _set_horovod_backend(self): # Initialize Horovod to get rank / size info hvd.init() - if self.on_gpu: + if self.has_gpu: # Horovod assigns one local GPU per process self.parallel_device_ids = list(range(hvd.local_size())) else: @@ -694,6 +761,12 @@ def has_horovodrun() -> bool: """Returns True if running with `horovodrun` using Gloo or OpenMPI.""" return "OMPI_COMM_WORLD_RANK" in os.environ or "HOROVOD_RANK" in os.environ + def update_device_type_if_ipu_plugin(self) -> None: + # This allows the poptorch.Options that are passed into the IPUPlugin to be the source of truth, + # which gives users the flexibility to not have to pass `ipus` flag directly to Trainer + if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU: + self._device_type = DeviceType.IPU + def configure_slurm_ddp(self): # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7475cd9c81326..c7267191e91ce 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1213,7 +1213,7 @@ def __setup_profiler(self) -> None: def _log_device_info(self) -> None: rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}') - num_tpu_cores = self.tpu_cores if self.tpu_cores is not None else 0 + num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0 rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores') num_ipus = self.ipus if self.ipus is not None else 0 diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 4a9b01281f784..e2f8ca0a4074d 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -587,3 +587,55 @@ def test_accelerator_choice_multi_node_gpu( gpus=gpus, ) assert isinstance(trainer.training_type_plugin, plugin) + + +@pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't require GPU") +def test_accelerator_cpu(): + + trainer = Trainer(accelerator="cpu") + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) + + with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but GPUs are not available"): + trainer = Trainer(accelerator="gpu") + + with pytest.raises(MisconfigurationException, match="You requested GPUs:"): + trainer = Trainer(accelerator="cpu", gpus=1) + + +@RunIf(min_gpus=1) +def test_accelerator_gpu(): + + trainer = Trainer(accelerator="gpu", gpus=1) + + assert trainer._device_type == "gpu" + assert isinstance(trainer.accelerator, GPUAccelerator) + + with pytest.raises( + MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`" + ): + trainer = Trainer(accelerator="gpu") + + trainer = Trainer(accelerator="auto", gpus=1) + + assert trainer._device_type == "gpu" + assert isinstance(trainer.accelerator, GPUAccelerator) + + +@RunIf(min_gpus=1) +def test_accelerator_cpu_with_gpus_flag(): + + trainer = Trainer(accelerator="cpu", gpus=1) + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) + + +@RunIf(min_gpus=2) +def test_accelerator_cpu_with_multiple_gpus(): + + trainer = Trainer(accelerator="cpu", gpus=2) + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 78176b76f5606..e4e5ebe1b7827 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from pytorch_lightning import Callback, seed_everything, Trainer -from pytorch_lightning.accelerators import IPUAccelerator +from pytorch_lightning.accelerators import CPUAccelerator, IPUAccelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin from pytorch_lightning.trainer.states import RunningStage @@ -491,3 +491,31 @@ def test_precision_plugin(tmpdir): plugin = IPUPrecisionPlugin(precision=16) assert plugin.precision == 16 + + +@RunIf(ipu=True) +def test_accelerator_ipu(): + + trainer = Trainer(accelerator="ipu", ipus=1) + + assert trainer._device_type == "ipu" + assert isinstance(trainer.accelerator, IPUAccelerator) + + with pytest.raises( + MisconfigurationException, match="You passed `accelerator='ipu'`, but you didn't pass `ipus` to `Trainer`" + ): + trainer = Trainer(accelerator="ipu") + + trainer = Trainer(accelerator="auto", ipus=8) + + assert trainer._device_type == "ipu" + assert isinstance(trainer.accelerator, IPUAccelerator) + + +@RunIf(ipu=True) +def test_accelerator_cpu_with_ipus_flag(): + + trainer = Trainer(accelerator="cpu", ipus=1) + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index b57894816090d..574f97deeafe6 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -16,6 +16,9 @@ from torch import nn from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.cpu import CPUAccelerator +from pytorch_lightning.accelerators.tpu import TPUAccelerator +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf from tests.helpers.utils import pl_multi_process_test @@ -113,3 +116,35 @@ def on_post_move_to_device(self): with pytest.warns(UserWarning, match="The model layers do not match"): trainer.fit(model) + + +@RunIf(tpu=True) +def test_accelerator_tpu(): + + trainer = Trainer(accelerator="tpu", tpu_cores=8) + + assert trainer._device_type == "tpu" + assert isinstance(trainer.accelerator, TPUAccelerator) + + with pytest.raises( + MisconfigurationException, match="You passed `accelerator='tpu'`, but you didn't pass `tpu_cores` to `Trainer`" + ): + trainer = Trainer(accelerator="tpu") + + +@RunIf(tpu=True) +def test_accelerator_cpu_with_tpu_cores_flag(): + + trainer = Trainer(accelerator="cpu", tpu_cores=8) + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) + + +@RunIf(tpu=True) +def test_accelerator_tpu_with_auto(): + + trainer = Trainer(accelerator="auto", tpu_cores=8) + + assert trainer._device_type == "tpu" + assert isinstance(trainer.accelerator, TPUAccelerator)