diff --git a/pl_examples/hpu_examples/simple_mnist/mnist.py b/pl_examples/hpu_examples/simple_mnist/mnist.py new file mode 100644 index 0000000000000..d84f2ac473688 --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/mnist.py @@ -0,0 +1,51 @@ +import os +import sys + +import habana_frameworks.torch.core as htcore +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, random_split +from torchvision import transforms +from torchvision.datasets import MNIST + +import pytorch_lightning as pl + + +class MNISTModel(pl.LightningModule): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(28 * 28, 10) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + + def training_step(self, batch, batch_nb): + x, y = batch + loss = F.cross_entropy(self(x), y) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + +# Init our model +mnist_model = MNISTModel() + +# Init DataLoader from MNIST Dataset +train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) +train_loader = DataLoader(train_ds, batch_size=32) + +# TBD: import these keys from hmp +hmp_keys = ["level", "verbose", "bf16_ops", "fp32_ops"] +hmp_params = dict.fromkeys(hmp_keys) +hmp_params["level"] = "O1" +hmp_params["verbose"] = False +hmp_params["bf16_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt" +hmp_params["fp32_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt" + +# Initialize a trainer +trainer = pl.Trainer(hpus=1, max_epochs=1, precision=16, hmp_params=hmp_params) + +# Train the model ⚡ +trainer.fit(mnist_model, train_loader) diff --git a/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt new file mode 100644 index 0000000000000..53ec99c15b4ce --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt @@ -0,0 +1,2 @@ +linear +relu diff --git a/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt new file mode 100644 index 0000000000000..4509b7e58ac29 --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt @@ -0,0 +1 @@ +cross_entropy diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py index 1c9e0024f39bd..27e580fa5b496 100644 --- a/pytorch_lightning/accelerators/__init__.py +++ b/pytorch_lightning/accelerators/__init__.py @@ -13,5 +13,6 @@ from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401 from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401 +from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401 diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0c3f2bf1901ba..e0518bd79b3c0 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -41,6 +41,7 @@ class Accelerator: - GPU - TPU - IPU + - HPU Each Accelerator gets two plugins upon initialization: One to handle differences from the training routine and one to handle different precisions. 
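The MNIST example above hard-codes the `bf16_ops`/`fp32_ops` paths to a `pytorch-lightning-fork` checkout, while the two new `.txt` files show the expected format: one op name per line. As a minimal sketch (not part of the patch), the same `hmp_params` dict could be built with paths resolved relative to the script, so the example does not depend on where the repository is cloned:

```python
import os

# Sketch only: the keys below are the ones consumed by HPUPrecisionPlugin, which
# forwards them to hmp.convert(opt_level=..., bf16_file_path=..., fp32_file_path=...,
# isVerbose=...). Paths are resolved next to this script instead of a fixed checkout.
_HERE = os.path.dirname(os.path.abspath(__file__))

hmp_params = {
    "level": "O1",      # hmp optimization level
    "verbose": False,   # hmp verbosity
    "bf16_ops": os.path.join(_HERE, "ops_bf16_mnist.txt"),  # e.g. "linear", "relu" (one op per line)
    "fp32_ops": os.path.join(_HERE, "ops_fp32_mnist.txt"),  # e.g. "cross_entropy"
}
```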
diff --git a/pytorch_lightning/accelerators/hpu.py b/pytorch_lightning/accelerators/hpu.py new file mode 100644 index 0000000000000..247539c3e200b --- /dev/null +++ b/pytorch_lightning/accelerators/hpu.py @@ -0,0 +1,50 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any + +import torch + +import pytorch_lightning as pl +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins import DataParallelPlugin +from pytorch_lightning.plugins.precision.hpu_precision import HPUPrecisionPlugin +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.plugins.training_type.hpu import HPUPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +_log = logging.getLogger(__name__) + + +class HPUAccelerator(Accelerator): + """Accelerator for HPU devices.""" + + def setup(self, trainer: "pl.Trainer") -> None: + """ + Raises: + ValueError: + If the precision or training type plugin is unsupported. + """ + if not isinstance(self.precision_plugin, HPUPrecisionPlugin): + # this configuration should have been avoided in the accelerator connector + raise ValueError( + f"The `HPUAccelerator` can only be used with an `HPUPrecisionPlugin`, found: {self.precision_plugin}." + ) + if not isinstance(self.training_type_plugin, (HPUPlugin, DDPPlugin)): + raise ValueError( + "The `HPUAccelerator` can only be used with an `HPUPlugin` or `DDPPlugin`," + f" found {self.training_type_plugin}." + ) + return super().setup(trainer) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c59193859b171..b6f8664d54675 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -251,6 +251,14 @@ def on_gpu(self): """ return self.device.type == "cuda" + @property + def on_hpu(self): + """True if your model is currently running on HPUs. + + Useful to set flags around the LightningModule for different CPU vs GPU vs HPU behavior. + """ + return self.device.type == "hpu" + @property def automatic_optimization(self) -> bool: """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``.""" @@ -1586,6 +1594,7 @@ def optimizer_step( optimizer_idx: int = 0, optimizer_closure: Optional[Callable[[], Any]] = None, on_tpu: bool = False, + on_hpu: bool = False, using_native_amp: bool = False, using_lbfgs: bool = False, ) -> None: @@ -1604,6 +1613,7 @@ def optimizer_step( optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
on_tpu: ``True`` if TPU backward is required + on_hpu: ``True`` if HPU backward is required using_native_amp: ``True`` if using native amp using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py index 3cbbe5ea760ff..7f2617dd4fdae 100644 --- a/pytorch_lightning/overrides/torch_distributed.py +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -3,6 +3,7 @@ import torch +from pytorch_lightning.utilities import _HPU_AVAILABLE from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 log = logging.getLogger(__name__) @@ -53,6 +54,10 @@ def _broadcast_object_list(object_list, src=0, group=None): group_backend = get_backend(group) is_nccl_backend = group_backend == Backend.NCCL + import os + + dist_backend = os.environ.get("PL_TORCH_DISTRIBUTED_BACKEND") + is_hcl_backend = group_backend == torch.distributed.Backend(str(dist_backend)) current_device = torch.device("cpu") if is_nccl_backend: # See note about using torch.cuda.current_device() here in docstring. # true. current_device = torch.device("cuda", torch.cuda.current_device()) object_sizes_tensor = object_sizes_tensor.to(current_device) + elif is_hcl_backend: + current_device = torch.device("hpu") + # Workaround: HPU does not support long tensors for collectives + object_sizes_tensor = object_sizes_tensor.int() object_sizes_tensor = object_sizes_tensor.to(current_device) # Broadcast object sizes @@ -73,6 +82,8 @@ def _broadcast_object_list(object_list, src=0, group=None): if is_nccl_backend: object_tensor = object_tensor.to(current_device) + elif is_hcl_backend: + object_tensor = object_tensor.to(current_device) broadcast(object_tensor, src=src, group=group) @@ -93,7 +104,7 @@ def _broadcast_noop(obj, *_, **__): return obj broadcast_object_list = _broadcast_noop -elif _TORCH_GREATER_EQUAL_1_8: +elif _TORCH_GREATER_EQUAL_1_8 and not _HPU_AVAILABLE: from torch.distributed.distributed_c10d import broadcast_object_list else: broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 0194591bfc06c..ef32b50c6c00d 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -13,6 +13,7 @@ from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.hpu_precision import HPUPrecisionPlugin from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin @@ -26,6 +27,7 @@ from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin +from pytorch_lightning.plugins.training_type.hpu import HPUPlugin from pytorch_lightning.plugins.training_type.ipu import IPUPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.plugins.training_type.sharded
import DDPShardedPlugin @@ -54,6 +56,8 @@ "HorovodPlugin", "IPUPlugin", "IPUPrecisionPlugin", + "HPUPlugin", + "HPUPrecisionPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/hpu_precision.py b/pytorch_lightning/plugins/precision/hpu_precision.py new file mode 100644 index 0000000000000..85b618f5f9c26 --- /dev/null +++ b/pytorch_lightning/plugins/precision/hpu_precision.py @@ -0,0 +1,52 @@ +# Copyright (C) 2021 Habana Labs, Ltd. an Intel Company +# All Rights Reserved. +# +# Unauthorized copying of this file or any element(s) within it, via any medium +# is strictly prohibited. +# This file contains Habana Labs, Ltd. proprietary and confidential information +# and is subject to the confidentiality and license agreements under which it +# was provided. +# + +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, List, Optional, Tuple + +import torch.nn as nn +from habana_frameworks.torch.hpex import hmp +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + + +class HPUPrecisionPlugin(PrecisionPlugin): + """Plugin that enables bfloat16/float32 mixed precision on HPUs.""" + + def __init__(self, precision: int, hmp_params: Optional[dict]) -> None: + super().__init__() + self.precision = precision + if hmp_params is not None: + hmp_opt_level = hmp_params["level"] + hmp_bf16 = hmp_params["bf16_ops"] + hmp_fp32 = hmp_params["fp32_ops"] + hmp_verbose = hmp_params["verbose"] + hmp.convert( + opt_level=hmp_opt_level, bf16_file_path=hmp_bf16, fp32_file_path=hmp_fp32, isVerbose=hmp_verbose + ) + + def connect( + self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: + return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 6a56d68e17db9..f21d0a0383d0a 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -5,6 +5,7 @@ from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.hpu import HPUPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index ea4820f61ec7c..1e685fb21b318 ---
a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -374,7 +374,7 @@ def configure_ddp(self) -> None: self._register_ddp_hooks() def determine_ddp_device_ids(self): - if self.root_device.type == "cpu": + if self.root_device.type == "cpu" or self.root_device.type == "hpu": return None return [self.root_device.index] @@ -534,6 +534,14 @@ def reconciliate_processes(self, trace: str) -> None: shutil.rmtree(sync_dir) raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}") + def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]: + if self.root_device.type == "hpu" and self.cluster_environment.global_rank() == 0: + from pytorch_lightning.utilities.apply_func import move_data_to_device + + return move_data_to_device(checkpoint, torch.device("cpu")) + else: + return checkpoint + def teardown(self) -> None: if isinstance(self.model, DistributedDataParallel): self.model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/hpu.py b/pytorch_lightning/plugins/training_type/hpu.py new file mode 100644 index 0000000000000..4522cf504f4d6 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/hpu.py @@ -0,0 +1,76 @@ +# Copyright (C) 2021 Habana Labs, Ltd. an Intel Company +# All Rights Reserved. +# +# Unauthorized copying of this file or any element(s) within it, via any medium +# is strictly prohibited. +# This file contains Habana Labs, Ltd. proprietary and confidential information +# and is subject to the confidentiality and license agreements under which it +# was provided. +# + +import os +from typing import Any, Dict, Optional + +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin +from pytorch_lightning.utilities import _HPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import _PATH + + +class HPUPlugin(SingleDevicePlugin): + def __init__( + self, + device: int, + checkpoint_io: Optional[CheckpointIO] = None, + debug: bool = False, + ): + + device = torch.device("hpu") + checkpoint_io = checkpoint_io + super().__init__(device, checkpoint_io=checkpoint_io) + + self.debug = debug + + @property + def is_distributed(self) -> bool: + return False + + def setup(self) -> None: + shared_params = find_shared_parameters(self.model) + self.model_to_device() + if is_overridden("on_post_move_to_device", self.lightning_module): + self.model.on_post_move_to_device() + else: + set_shared_parameters(self.model, shared_params) + + def model_to_device(self) -> None: + self.model.to(self.root_device) + + @property + def on_hpu(self) -> bool: + return True + + def pre_dispatch(self) -> None: + if isinstance(self.device, int): + self.device = torch.device(self.device) + + def on_save(self, checkpoint: dict) -> dict: + return move_data_to_device(checkpoint, torch.device("cpu")) diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41ca8..6e70f7442e25f 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -38,6 +38,10 @@ def __init__( def on_tpu(self) -> bool: return self.root_device.type == "xla" and _XLA_AVAILABLE + @property + def on_hpu(self) -> bool: + return self.device.type == "hpu" + @property def on_gpu(self) -> bool: return self.root_device.type == "cuda" and torch.cuda.is_available() diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5895c1c6a141e..cdf37cfa48f8b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -22,6 +22,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.hpu import HPUAccelerator from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( @@ -39,6 +40,8 @@ DoublePrecisionPlugin, FullyShardedNativeMixedPrecisionPlugin, HorovodPlugin, + HPUPlugin, + HPUPrecisionPlugin, IPUPlugin, IPUPrecisionPlugin, NativeMixedPrecisionPlugin, @@ -73,6 +76,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( _HOROVOD_AVAILABLE, + _HPU_AVAILABLE, _IPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, @@ -96,6 +100,7 @@ def __init__( strategy: Optional[Union[str, TrainingTypePlugin]], gpus, gpu_ids, + hpus, num_nodes, sync_batchnorm, benchmark, @@ -104,6 +109,7 @@ def __init__( precision, amp_type, amp_level, + 
hmp_params, plugins, ): # initialization @@ -124,6 +130,7 @@ def __init__( self.parallel_device_ids = gpu_ids self.tpu_cores = tpu_cores self.ipus = ipus + self.hpus = hpus self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.benchmark = benchmark @@ -135,6 +142,7 @@ def __init__( self.precision = precision self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None self.amp_level = amp_level + self.hmp_params = hmp_params self._is_slurm_managing_tasks = False self._precision_plugin: Optional[PrecisionPlugin] = None @@ -209,6 +217,8 @@ def select_accelerator_type(self) -> None: self._accelerator_type = DeviceType.IPU elif self.has_gpu: self._accelerator_type = DeviceType.GPU + elif self.has_hpu: + self._accelerator_type = DeviceType.HPU else: self._set_devices_to_cpu_num_processes() self._accelerator_type = DeviceType.CPU @@ -227,6 +237,11 @@ def select_accelerator_type(self) -> None: msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available" raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.") self._accelerator_type = DeviceType.GPU + elif self.distributed_backend == DeviceType.HPU: + if not self.has_hpu: + msg = "HPUs are not available" if not _HPU_AVAILABLE else "you didn't pass `hpus` to `Trainer`" + raise MisconfigurationException(f"You passed `accelerator='hpu'`, but {msg}.") + self._accelerator_type = DeviceType.HPU elif self.distributed_backend == DeviceType.CPU: self._set_devices_to_cpu_num_processes() self._accelerator_type = DeviceType.CPU @@ -238,7 +253,7 @@ def _validate_accelerator_and_devices(self) -> None: if self.distributed_backend not in self.accelerator_types and self.devices is not None: raise MisconfigurationException( f"You passed `devices={self.devices}` but haven't specified" - " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping," + " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu')` for the devices mapping," f" got `accelerator={self.distributed_backend!r}`." ) @@ -264,6 +279,9 @@ def _warn_if_devices_flag_ignored(self) -> None: elif self.distributed_backend in ("auto", DeviceType.GPU): if self.gpus is not None: rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`") + elif self.distributed_backend in ("auto", DeviceType.HPU): + if self.hpus is not None: + rank_zero_warn(f"{devices_warning} `hpus={self.hpus}`") elif self.distributed_backend in ("auto", DeviceType.CPU): if self.num_processes != 1: rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`") @@ -277,6 +295,8 @@ def _set_devices_if_none(self) -> None: self.devices = self.ipus elif self._accelerator_type == DeviceType.GPU: self.devices = self.gpus + elif self._accelerator_type == DeviceType.HPU: + self.devices = self.hpus elif self._accelerator_type == DeviceType.CPU: self.devices = self.num_processes @@ -459,6 +479,18 @@ def tpu_id(self) -> Optional[int]: return self.tpu_cores[0] return None + @property + def has_hpu(self) -> bool: + # Here, we are not checking for HPU availability, but instead if User has passed + # `hpus` to Trainer for training. 
+ if self.hpus is not None or isinstance(self._training_type_plugin, HPUPlugin): + return True + return self._map_devices_to_accelerator(DeviceType.HPU) + + @property + def use_hpu(self) -> bool: + return self._accelerator_type == DeviceType.HPU and self.has_hpu + @property def has_ipu(self) -> bool: # Here, we are not checking for IPU availability, but instead if User has passed @@ -488,6 +520,9 @@ def _map_devices_to_accelerator(self, accelerator: str) -> bool: self.devices = IPUAccelerator.auto_device_count() self.ipus = self.devices return True + if accelerator == DeviceType.HPU and _HPU_AVAILABLE: + self.hpus = self.devices + return True if accelerator == DeviceType.GPU and torch.cuda.is_available(): if self.devices == "auto": self.devices = GPUAccelerator.auto_device_count() @@ -568,6 +603,14 @@ def num_ipus(self) -> int: return self._training_type_plugin.replication_factor return 0 + @property + def num_hpus(self) -> int: + if isinstance(self.hpus, int): + return self.hpus + if isinstance(self._training_type_plugin, HPUPlugin): + return self._training_type_plugin.replication_factor + return 0 + @property def parallel_devices(self) -> List[Union[torch.device, int]]: if self.use_gpu: @@ -579,6 +622,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: devices = list(range(self.tpu_cores)) elif self.use_ipu: devices = list(range(self.num_ipus)) + elif self.use_hpu: + devices = [torch.device("hpu")] * self.num_processes else: devices = [torch.device("cpu")] * self.num_processes return devices @@ -587,7 +632,7 @@ def parallel_devices(self) -> List[Union[torch.device, int]]: def root_gpu(self) -> Optional[int]: return ( self.accelerator.root_device.index - if not isinstance(self.accelerator, (IPUAccelerator, TPUAccelerator)) + if not isinstance(self.accelerator, (IPUAccelerator, TPUAccelerator, HPUAccelerator)) else None ) @@ -638,6 +683,9 @@ def select_precision_plugin(self) -> PrecisionPlugin: ) return TPUBf16PrecisionPlugin() + if self.use_hpu: + return HPUPrecisionPlugin(self.precision, self.hmp_params) + if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): return DeepSpeedPrecisionPlugin(self.precision) @@ -752,6 +800,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = SingleTPUPlugin(self.tpu_id) elif self.use_ipu: plugin = IPUPlugin(parallel_devices=self.parallel_devices) + elif self.use_hpu: + plugin = HPUPlugin(device=torch.device("hpu")) else: single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids) plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.use_gpu else "cpu")) @@ -796,6 +846,8 @@ def select_accelerator(self) -> Accelerator: acc_cls = TPUAccelerator elif self.use_ipu: acc_cls = IPUAccelerator + elif self.use_hpu: + acc_cls = HPUAccelerator else: acc_cls = CPUAccelerator # as precision_plugin is dependent on training_type_plugin, make sure @@ -841,6 +893,8 @@ def set_distributed_mode(self, strategy: Optional[str] = None): if self.distributed_backend is None: if self.has_horovodrun(): self._set_horovod_backend() + elif self.num_hpus > 1 and not _use_cpu: + self._distrib_type = DistributedType.DDP elif self.num_gpus == 0 and self.num_nodes > 1: self._distrib_type = DistributedType.DDP elif self.num_gpus == 0 and self.num_processes > 1: @@ -878,16 +932,20 @@ def set_distributed_mode(self, strategy: Optional[str] = None): self._distrib_type = DistributedType.TPU_SPAWN elif self.has_ipu and not _use_cpu: 
self._device_type = DeviceType.IPU + elif self.has_hpu and not _use_cpu: + self._device_type = DeviceType.HPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) if self.num_gpus > 0 and not _use_cpu: self._device_type = DeviceType.GPU + if self.num_hpus > 0 and not _use_cpu: + self._device_type = DeviceType.HPU + _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) # DP and DDP2 cannot run without GPU - if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu: - + if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu and not (self.num_hpus > 1): if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): if self._distrib_type in (DistributedType.DP, DistributedType.DDP2): rank_zero_warn( @@ -911,6 +969,9 @@ def set_distributed_mode(self, strategy: Optional[str] = None): if self._device_type == DeviceType.GPU and self._distrib_type == DistributedType.DDP2: self.num_processes = self.num_nodes + if self._device_type == DeviceType.HPU and self._distrib_type == DistributedType.DDP: + self.num_processes = self.num_hpus + # Horovod is an extra case... if self.distributed_backend == DistributedType.HOROVOD: self._set_horovod_backend() @@ -985,6 +1046,8 @@ def update_device_type_if_training_type_plugin_passed(self) -> None: self._device_type = DeviceType.TPU elif self.use_gpu: self._device_type = DeviceType.GPU + elif self.use_hpu: + self._device_type = DeviceType.HPU else: if self.has_ipu: self._device_type = DeviceType.IPU @@ -992,6 +1055,12 @@ def update_device_type_if_training_type_plugin_passed(self) -> None: self._device_type = DeviceType.TPU elif self.has_gpu: self._device_type = DeviceType.GPU + elif self.has_hpu: + self._device_type = DeviceType.HPU + + def update_device_type_if_hpu_plugin(self) -> None: + if isinstance(self._training_type_plugin, HPUPlugin) and self._device_type != DeviceType.HPU: + self._device_type = DeviceType.HPU @property def is_slurm_managing_tasks(self) -> bool: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index e149aef9a7997..4505d19b254a2 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -51,6 +51,7 @@ class TrainerDataLoadingMixin(ABC): # the proper values/initialisation should be done in child class val_check_interval: float tpu_local_core_rank: int + hpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] val_check_batch: float diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7e5d21e18dc26..4f2cc7e54b766 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,7 +27,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.accelerators import Accelerator, IPUAccelerator +from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, IPUAccelerator from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule @@ -65,6 +65,7 @@ from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( + 
_HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE, device_parser, @@ -130,6 +131,7 @@ def __init__( devices: Optional[Union[List[int], str, int]] = None, gpus: Optional[Union[List[int], str, int]] = None, auto_select_gpus: bool = False, + hpus: Optional[int] = None, tpu_cores: Optional[Union[List[int], str, int]] = None, ipus: Optional[int] = None, log_gpu_memory: Optional[str] = None, # TODO: Remove in 1.7 @@ -174,6 +176,7 @@ def __init__( plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, amp_backend: str = "native", amp_level: Optional[str] = None, + hmp_params: Optional[dict] = None, move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = "max_size_cycle", stochastic_weight_avg: bool = False, @@ -184,7 +187,7 @@ def __init__( Args: - accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto") + accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "auto") as well as custom accelerator instances. .. deprecated:: v1.5 @@ -239,7 +242,7 @@ def __init__( deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms. Default: ``False``. - devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`, + devices: Will be mapped to either `gpus`, `tpu_cores`, `hpus`, `num_processes` or `ipus`, based on the accelerator type. fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) @@ -381,6 +384,10 @@ def __init__( ipus: How many IPUs to train on. + hpus: How many HPUs to train on. + + hmp_params: Habana Mixed Precision (hmp) parameters, a dict with the keys ``level``, ``verbose``, ``bf16_ops`` and ``fp32_ops``. + track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them. @@ -437,6 +444,7 @@ def __init__( strategy, gpus, gpu_ids, + hpus, num_nodes, sync_batchnorm, benchmark, @@ -445,6 +453,7 @@ def __init__( precision, amp_backend, amp_level, + hmp_params, plugins, ) self.logger_connector = LoggerConnector(self, log_gpu_memory) @@ -826,6 +835,12 @@ def _validate_impl( # -------------------- Trainer._log_api_event("validate") self.verbose_evaluate = verbose + # log hyper-parameters + if self.logger is not None: + # save exp to get started (this is where the first experiment logs are written) + # self.logger.log_hyperparams(self.lightning_module.hparams_initial) + self.logger.log_graph(self.lightning_module) + self.logger.save() self.state.fn = TrainerFn.VALIDATING self.state.status = TrainerStatus.RUNNING @@ -1562,6 +1577,9 @@ def _log_device_info(self) -> None: num_ipus = self.ipus if self.ipus is not None else 0 rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs") + num_hpus = self.hpus if self.hpus is not None else 0 + rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs") + if torch.cuda.is_available() and self._device_type != DeviceType.GPU: rank_zero_warn( "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`." ) @@ -1579,6 +1597,12 @@ def _log_device_info(self) -> None: " `Trainer(ipus=8)` or script `--ipus=8`." ) + if _HPU_AVAILABLE and self._device_type != DeviceType.HPU and not isinstance(self.accelerator, HPUAccelerator): + rank_zero_warn( + "HPU available but not used. Set the `hpus` flag in your trainer" + " `Trainer(hpus=8)` or script `--hpus=8`."
+ ) + def _on_exception(self): if not _fault_tolerant_training(): return @@ -1657,6 +1681,10 @@ def ipus(self) -> int: def num_gpus(self) -> int: return self._accelerator_connector.num_gpus + @property + def hpus(self) -> int: + return self._accelerator_connector.num_hpus + @property def devices(self) -> Optional[Union[List[int], str, int]]: return self._accelerator_connector.devices diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 158d7356c91ce..91e8d90e762e9 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -34,6 +34,7 @@ _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, _GROUP_AVAILABLE, _HOROVOD_AVAILABLE, + _HPU_AVAILABLE, _HYDRA_AVAILABLE, _HYDRA_EXPERIMENTAL_AVAILABLE, _IPU_AVAILABLE, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 61443bea07cd7..ef4db526bf70f 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -242,7 +242,7 @@ def add_argparse_args( else: use_type = arg_types[0] - if arg == "gpus" or arg == "tpu_cores": + if arg == "gpus" or arg == "tpu_cores" or arg == "hpus": use_type = _gpus_allowed_type # hack for types in (int, float) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 47fa7b791eae0..92bb044e1614d 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -22,7 +22,12 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE +from pytorch_lightning.utilities.imports import ( + _HPU_AVAILABLE, + _TORCH_GREATER_EQUAL_1_8, + _TORCH_GREATER_EQUAL_1_9, + _TPU_AVAILABLE, +) if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -208,6 +213,9 @@ def forward( gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + if _HPU_AVAILABLE: + # HPU distributed backend doesn't support int64 tensors + tensor = tensor.int() torch.distributed.all_gather(gathered_tensor, tensor, group=group) gathered_tensor = torch.stack(gathered_tensor, dim=0) @@ -381,6 +389,18 @@ def init_dist_connection( world_size = world_size if world_size is not None else cluster_environment.world_size() os.environ["MASTER_ADDR"] = cluster_environment.master_address() os.environ["MASTER_PORT"] = str(cluster_environment.master_port()) + + # local rank mapping for device open is needed for hpu devices + if torch_distributed_backend == "hcl" or torch_distributed_backend == "hccl": + try: + import habana_frameworks.torch.core.hccl + except Exception: + print("hccl backend is not supported, using hcl backend") + torch_distributed_backend = "hcl" + os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "hcl" + + os.environ["ID"] = str(cluster_environment.local_rank()) + if torch.distributed.is_available() and not torch.distributed.is_initialized(): log.info(f"initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch.distributed.init_process_group( diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 436c675c382c2..3abde959d950a 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -128,6 +128,7 @@ class DeviceType(LightningEnum): GPU = "GPU" IPU = "IPU" TPU = "TPU" + HPU = "HPU" class 
GradClipAlgorithmType(LightningEnum): diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index edf5f75aee6a9..da7abc8101355 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -109,6 +109,9 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: else: _IPU_AVAILABLE = False +from habana_frameworks.torch.utils.library_loader import is_habana_avaialble + +_HPU_AVAILABLE = is_habana_avaialble() # experimental feature within PyTorch Lightning. def _fault_tolerant_training() -> bool: diff --git a/tests/accelerators/test_hpu.py b/tests/accelerators/test_hpu.py new file mode 100644 index 0000000000000..f01bb0992a775 --- /dev/null +++ b/tests/accelerators/test_hpu.py @@ -0,0 +1,187 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F + +from pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning.accelerators import CPUAccelerator, HPUAccelerator +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins import HPUPlugin, HPUPrecisionPlugin +from pytorch_lightning.trainer.states import RunningStage, TrainerFn +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.utilities import _HPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.runif import RunIf +from tests.helpers.simple_models import ClassificationModel + + +class HPUModel(BoringModel): + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return loss + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + +@pytest.mark.skipif(_HPU_AVAILABLE, reason="test requires non-HPU machine") +def test_fail_if_no_hpus(tmpdir): + with pytest.raises(MisconfigurationException, match="HPU Accelerator requires HPU devices to run"): + Trainer(default_root_dir=tmpdir, hpus=1) + + with pytest.raises(MisconfigurationException, match="HPU Accelerator requires HPU devices to run"): + Trainer(default_root_dir=tmpdir, hpus=1, accelerator="hpu") + + +@RunIf(hpu=True) +def test_accelerator_selected(tmpdir): + trainer = Trainer(default_root_dir=tmpdir, hpus=1) + assert isinstance(trainer.accelerator, HPUAccelerator) + trainer = Trainer(default_root_dir=tmpdir, hpus=1, accelerator="hpu") + assert isinstance(trainer.accelerator, HPUAccelerator) + + 
+@RunIf(hpu=True) +def test_warning_if_hpus_not_used(tmpdir): + with pytest.warns(UserWarning, match="HPU available but not used. Set the `hpus` flag in your trainer"): + Trainer(default_root_dir=tmpdir) + + +## TBD +@pytest.mark.skipif(_HPU_AVAILABLE, reason="PyTorch is not linked with support for hpu devices") +@RunIf(hpu=True) +@pytest.mark.parametrize("hpus", [1, 8]) +def test_all_stages(tmpdir, hpus): + if hpus > 1: + os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "hcl" + model = HPUModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, hpus=hpus) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + + +@RunIf(hpu=True) +def test_mixed_precision(tmpdir): + class TestCallback(Callback): + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + assert trainer.accelerator.model.precision == 16 + raise SystemExit + + hmp_keys = ["level", "verbose", "bf16_ops", "fp32_ops"] + hmp_params = dict.fromkeys(hmp_keys) + hmp_params["level"] = "O1" + hmp_params["verbose"] = False + hmp_params["bf16_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt" + hmp_params["fp32_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt" + + model = HPUModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + hpus=1, + precision=16, + hmp_params=hmp_params, + callbacks=TestCallback(), + ) + assert isinstance(trainer.accelerator.precision_plugin, HPUPrecisionPlugin) + assert trainer.accelerator.precision_plugin.precision == 16 + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(hpu=True) +def test_precision_plugin(tmpdir): + """Ensure precision plugin value is set correctly.""" + + hmp_keys = ["level", "verbose", "bf16_ops", "fp32_ops"] + hmp_params = dict.fromkeys(hmp_keys) + hmp_params["level"] = "O1" + hmp_params["verbose"] = False + hmp_params["bf16_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt" + hmp_params["fp32_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt" + + plugin = HPUPrecisionPlugin(precision=16, hmp_params=hmp_params) + assert plugin.precision == 16 + + +@RunIf(hpu=True) +def test_accelerator_hpu(): + + trainer = Trainer(accelerator="hpu", hpus=1) + + assert trainer._device_type == "hpu" + assert isinstance(trainer.accelerator, HPUAccelerator) + + with pytest.raises( + MisconfigurationException, match="You passed `accelerator='hpu'`, but you didn't pass `hpus` to `Trainer`" + ): + trainer = Trainer(accelerator="hpu") + + trainer = Trainer(accelerator="auto", hpus=8) + + assert trainer._device_type == "hpu" + assert isinstance(trainer.accelerator, HPUAccelerator) + + +@RunIf(hpu=True) +def test_accelerator_cpu_with_hpus_flag(): + + trainer = Trainer(accelerator="cpu", hpus=1) + + assert trainer._device_type == "cpu" + assert isinstance(trainer.accelerator, CPUAccelerator) + + +@RunIf(hpu=True) +def test_accelerator_hpu_with_devices(): + + trainer = Trainer(accelerator="hpu", devices=1) + + assert trainer.hpus == 1 + assert isinstance(trainer.training_type_plugin, HPUPlugin) + assert isinstance(trainer.accelerator, HPUAccelerator) + + +@RunIf(hpu=True) +def test_accelerator_auto_with_devices_hpu(): + + trainer = Trainer(accelerator="auto", devices=1) + + assert trainer._device_type == "hpu" + assert trainer.hpus == 1
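Taken together, the tests above outline the intended usage surface: `Trainer(hpus=1)` selects `HPUAccelerator` with `HPUPlugin`, `hpus > 1` maps to `DistributedType.DDP`, and the distributed backend is communicated through the `PL_TORCH_DISTRIBUTED_BACKEND` environment variable (with a fallback from `hccl` to `hcl` in `init_dist_connection`). Below is a hedged end-to-end sketch, assuming a machine with the Habana software stack and HPU devices available; the module and the random dataset are illustrative only and not part of the PR:

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class HPUAwareModel(pl.LightningModule):
    """Toy module exercising the new `LightningModule.on_hpu` property."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.l1(x), y)
        # `self.on_hpu` mirrors `self.on_gpu`: True when the module sits on an "hpu" device.
        self.log("on_hpu", float(self.on_hpu))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # Multi-card run mirroring test_all_stages(hpus=8): hpus > 1 triggers the DDP path,
    # and the backend name is passed along via PL_TORCH_DISTRIBUTED_BACKEND.
    os.environ.setdefault("PL_TORCH_DISTRIBUTED_BACKEND", "hcl")

    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    trainer = pl.Trainer(hpus=8, max_epochs=1)  # assumes habana_frameworks and 8 HPUs are present
    trainer.fit(HPUAwareModel(), data)
```

On a machine without HPUs this configuration is expected to fail early; `test_fail_if_no_hpus` pins the corresponding `MisconfigurationException` message.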