diff --git a/.azure-pipelines/hpu-tests.yml b/.azure-pipelines/hpu-tests.yml
index 1be1677045a6c..d846994175f40 100644
--- a/.azure-pipelines/hpu-tests.yml
+++ b/.azure-pipelines/hpu-tests.yml
@@ -31,3 +31,23 @@ jobs:
apt-get install -y hwinfo
hwinfo --short
displayName: 'Instance HW info'
+
+ - bash: |
+ pip install . --requirement requirements/test.txt
+ displayName: 'Install dependencies'
+
+ - bash: |
+ python ".azure-pipelines/run_hpu_tests.py"
+ displayName: 'HPU Tests in parallel'
+
+ - bash: |
+ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
+ python "pl_examples/hpu_examples/simple_mnist/mnist.py"
+ displayName: 'Testing: HPU examples'
+
+ - task: PublishTestResults@2
+ inputs:
+ testResultsFiles: 'hpu*_test-results.xml'
+ testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)'
+ condition: succeededOrFailed()
+ displayName: 'Publish test results'
diff --git a/.azure-pipelines/run_hpu_tests.py b/.azure-pipelines/run_hpu_tests.py
new file mode 100644
index 0000000000000..590c5d9c42251
--- /dev/null
+++ b/.azure-pipelines/run_hpu_tests.py
@@ -0,0 +1,148 @@
+"""This file is called from the hpu-tests.yml pipeline.
+
+The following script runs the HPU tests in parallel.
+Tests run are:
+1. test_inference_only is run on four cards
+2. test_all_stages on two cards
+3. complete hpu tests using one card
+4. complete hpu tests using eight cards
+5. hpu precision tests using one card.
+"""
+import itertools
+import subprocess
+import sys
+
+HPU_TESTS_DICTIONARY = {
+ "hpu1_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+ --forked \
+ --junitxml=hpu1_test-results.xml",
+ "hpu2_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+ -k test_all_stages \
+ --hpus 2 \
+ --verbose \
+ --capture=no \
+ --forked \
+ --junitxml=hpu2_test-results.xml",
+ "hpu4_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+ -k test_inference_only \
+ --hpus 4 \
+ --capture=no \
+ --verbose \
+ --forked \
+ --junitxml=hpu4_test-results.xml",
+ "hpu8_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+ --forked \
+ --hpus 8 \
+ --junitxml=hpu8_test-results.xml",
+ "hpu1_precision_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/plugins/precision/hpu/test_hpu.py \
+ --hmp-bf16 'tests/plugins/precision/hpu/ops_bf16.txt' \
+ --hmp-fp32 'tests/plugins/precision/hpu/ops_fp32.txt' \
+ --forked \
+ --junitxml=hpu1_precision_test-results.xml",
+}
+
+HPU1_TEST = HPU_TESTS_DICTIONARY["hpu1_test"]
+HPU2_TEST = HPU_TESTS_DICTIONARY["hpu2_test"]
+HPU4_TEST = HPU_TESTS_DICTIONARY["hpu4_test"]
+HPU8_TEST = HPU_TESTS_DICTIONARY["hpu8_test"]
+HPU1_PRECISION_TEST = HPU_TESTS_DICTIONARY["hpu1_precision_test"]
+
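+# Each inner list is a group of test commands that are launched concurrently;
+# the groups themselves are executed one after the other.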
+PARALLEL_HPU_TESTS_EXECUTION = [[HPU4_TEST, HPU1_TEST], [HPU2_TEST, HPU1_TEST], [HPU8_TEST], [HPU1_PRECISION_TEST]]
+TIMEOUT = 60 # seconds
+TIMEOUT_EXIT_CODE = -9
+
+
+def run_hpu_tests_parallel(timeout=TIMEOUT):
+ """This function is called to run the HPU tests in parallel.
+
+    We run the tests in subprocesses to utilize all eight cards available in the DL1 instance.
+    A process is killed if its run time exceeds the timeout, which is based on the maximum time the HPU tests take.
+
+ Args:
+        timeout: The threshold time to run the HPU tests in parallel.
+            An exception is logged if the timeout expires.
+            TIMEOUT_EXIT_CODE (-9) is recorded for a process that is killed on timeout;
+            otherwise the subprocess exit code is recorded (0 on success, non-zero on failure).
+
+ Return:
+        The list of exit statuses of the HPU tests that were run in the subprocesses.
+        Here, an exit status of 0 means the test run was successful and a non-zero exit status means it failed.
+ """
+ exit_status = []
+ with open("stdout_log.txt", "w") as stdout_log, open("error_log.txt", "w") as error_log:
+ for hpu_tests in PARALLEL_HPU_TESTS_EXECUTION:
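+            # Launch every command in this group concurrently; stdout/stderr are redirected to the shared log files.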
+ process_list = [
+ subprocess.Popen(
+ each_hpu_test, shell=True, stdout=stdout_log, stderr=error_log, universal_newlines=True
+ )
+ for each_hpu_test in hpu_tests
+ ]
+ for process in process_list:
+ try:
+                    exit_status.append(process.wait(timeout=timeout))
+ except subprocess.TimeoutExpired as e:
+ print(e)
+ print("Killing the process....")
+ process.kill()
+ exit_status.append(TIMEOUT_EXIT_CODE)
+ return exit_status
+
+
+def zip_cmd_exitcode(exit_status):
+    """This function is called to zip the tests that were executed with their exit statuses.
+
+ Args:
+ exit_status: The returned exit_status after executing run_hpu_tests_parallel().
+
+ Return:
+        A list of the HPU tests that were run and their exit statuses.
+ """
+    status_list = list(zip(list(itertools.chain(*PARALLEL_HPU_TESTS_EXECUTION)), exit_status))
+ return status_list
+
+
+def print_logs(filename):
+ """This function is called to read the file and print the logs.
+
+ Args:
+        filename: The log file whose contents need to be printed to the console.
+ """
+ with open(filename) as f:
+ print(f.read())
+
+
+def print_subprocess_logs_and_return_status(exit_status):
+    """This function is called to print the subprocess stdout and stderr logs and return the status of the test
+    execution.
+
+ Args:
+ exit_status: The returned exit_status after executing run_hpu_tests_parallel().
+
+ Return:
+        Based on the exit statuses of the HPU tests, success (0) or failure (1) is returned to the main method.
+ """
+ if all(v == 0 for v in exit_status):
+ print("All HPU tests passed")
+ file_name = "stdout_log.txt"
+ print_logs(file_name)
+ return 0
+ else:
+ print("HPU tests are failing")
+ print("Printing stdout_log.txt...")
+ file_name = "stdout_log.txt"
+ print_logs(file_name)
+ print("Printing error_log.txt...")
+ file_name = "error_log.txt"
+ print_logs(file_name)
+ return 1
+
+
+def main():
+ exit_status = run_hpu_tests_parallel(timeout=TIMEOUT)
+ status_list = zip_cmd_exitcode(exit_status)
+ print("HPU Tests executed and their exit status:", status_list)
+ return print_subprocess_logs_and_return_status(exit_status)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ab43e60a26a4..76e869240a87b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))
+- Added support for Habana Accelerator (HPU) ([#11808](https://github.com/PyTorchLightning/pytorch-lightning/pull/11808))
+
+
### Changed
- Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
diff --git a/docs/source/accelerators/hpu.rst b/docs/source/accelerators/hpu.rst
new file mode 100644
index 0000000000000..fd7bd310ffc43
--- /dev/null
+++ b/docs/source/accelerators/hpu.rst
@@ -0,0 +1,124 @@
+.. _hpu:
+
+Habana Gaudi AI Processor (HPU)
+===============================
+
+Lightning supports `Habana Gaudi AI Processor (HPU) `__ for accelerating Deep Learning training workloads.
+
+HPU Terminology
+---------------
+
+Habana® Gaudi® AI training processors are built on a heterogeneous architecture with a cluster of fully programmable Tensor Processing Cores (TPC) along with its associated development tools and libraries, and a configurable Matrix Math engine.
+
+The TPC core is a VLIW SIMD processor with an instruction set and hardware tailored to serve training workloads efficiently.
+The Gaudi memory architecture includes on-die SRAM and local memories in each TPC.
+Gaudi is the first DL training processor that has integrated RDMA over Converged Ethernet (RoCE v2) engines on-chip.
+
+On the software side, the PyTorch Habana bridge interfaces between the framework and SynapseAI software stack to enable the execution of deep learning models on the Habana Gaudi device.
+
+Gaudi offers a substantial price/performance advantage, so you get to do more deep learning training while spending less.
+
+For more information, check out `Gaudi Architecture `__ and `Gaudi Developer Docs `__.
+
+How to access HPUs
+------------------
+
+To use HPUs, you must have access to a system with HPU devices.
+You can either use `Gaudi-based AWS EC2 DL1 instances `__ or `Supermicro X12 Gaudi server `__ to get access to HPUs.
+
+Check out the `Getting Started Guide with AWS and Habana `__.
+
+Training with HPUs
+------------------
+
+To enable PyTorch Lightning to utilize the HPU accelerator, simply pass ``accelerator="hpu"`` to the Trainer class.
+
+.. code-block:: python
+
+ trainer = Trainer(accelerator="hpu")
+
+Passing ``devices=1`` and ``accelerator="hpu"`` to the Trainer class enables the Habana accelerator for single Gaudi training.
+
+.. code-block:: python
+
+ trainer = Trainer(devices=1, accelerator="hpu")
+
+Passing ``devices=8`` and ``accelerator="hpu"`` to the Trainer class enables the Habana accelerator for distributed training with 8 Gaudis.
+It uses :class:`~pytorch_lightning.strategies.hpu_parallel.HPUParallelStrategy` internally, which is based on the DDP strategy with the addition of Habana's collective communication library (HCCL) to support scale-up within a node and scale-out across multiple nodes.
+
+.. code-block:: python
+
+ trainer = Trainer(devices=8, accelerator="hpu")
+
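+If you need more control over how the parallel strategy is set up, you can also instantiate it explicitly and pass it
+to the Trainer. The snippet below is a minimal sketch that relies on the default arguments of
+:class:`~pytorch_lightning.strategies.hpu_parallel.HPUParallelStrategy`.
+
+.. code-block:: python
+
+    from pytorch_lightning.strategies import HPUParallelStrategy
+
+    trainer = Trainer(accelerator="hpu", devices=8, strategy=HPUParallelStrategy())
+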
+.. note::
+   If the ``devices`` flag is not defined, it defaults to ``"auto"``,
+   which selects 8 Gaudi devices for :class:`~pytorch_lightning.accelerators.hpu.HPUAccelerator`.
+
+
+Mixed Precision Plugin
+----------------------
+
+Lightning also allows mixed precision training with HPUs.
+By default, HPU training will use 32-bit precision. To enable mixed precision, set the ``precision`` flag.
+
+.. code-block:: python
+
+ trainer = Trainer(devices=1, accelerator="hpu", precision=16)
+
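+Since :class:`~pytorch_lightning.plugins.precision.hpu.HPUPrecisionPlugin` also accepts ``"bf16"``, bfloat16 mixed
+precision can be requested in the same way; the following is a minimal sketch of that variant.
+
+.. code-block:: python
+
+    trainer = Trainer(devices=1, accelerator="hpu", precision="bf16")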
+
+Enabling Mixed Precision Options
+--------------------------------
+
+Internally, :class:`~pytorch_lightning.plugins.precision.hpu.HPUPrecisionPlugin` uses the Habana Mixed Precision (HMP) package to enable mixed precision training.
+
+You can execute the ops in FP32 or BF16 precision. The HMP package modifies the Python operators to add the appropriate cast operations for the arguments before execution.
+The default settings let users enable mixed precision training easily and with minimal code changes.
+
+In addition to the default settings in HMP, users also have the option of overriding these defaults and providing their
+own BF16 and FP32 operator lists by passing them as parameters to :class:`~pytorch_lightning.plugins.precision.hpu.HPUPrecisionPlugin`.
+
+The below snippet shows an example model using MNIST with a single Habana Gaudi device and making use of HMP by overriding the default parameters.
+This enables advanced users to provide their own BF16 and FP32 operator list instead of using the HMP defaults.
+
+.. code-block:: python
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.plugins import HPUPrecisionPlugin
+
+    # Initialize a trainer with the HPU accelerator for a single HPU device,
+    # with mixed precision using overridden HMP settings
+ trainer = pl.Trainer(
+ accelerator="hpu",
+ devices=1,
+        # Optional Habana mixed precision params to be set
+        # Check out `pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt` for the format
+ plugins=[
+ HPUPrecisionPlugin(
+ precision=16,
+ opt_level="O1",
+ verbose=False,
+ bf16_file_path="ops_bf16_mnist.txt",
+ fp32_file_path="ops_fp32_mnist.txt",
+ )
+ ],
+ )
+
+ # Init our model
+ model = LitClassifier()
+ # Init the data
+ dm = MNISTDataModule(batch_size=batch_size)
+
+ # Train the model ⚡
+ trainer.fit(model, datamodule=dm)
+
+For more details, please refer to `PyTorch Mixed Precision Training on Gaudi `__.
+
+----------------
+
+.. _known-limitations_hpu:
+
+Known limitations
+-----------------
+
+* Multiple optimizers are not supported.
+* `Habana dataloader `__ is not supported.
+* :class:`~pytorch_lightning.callbacks.device_stats_monitor.DeviceStatsMonitor` is not supported.
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index b0bcf4e74ff8b..23a43990fed39 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -16,6 +16,7 @@ Accelerator API
Accelerator
CPUAccelerator
GPUAccelerator
+ HPUAccelerator
IPUAccelerator
TPUAccelerator
@@ -59,9 +60,11 @@ Strategy API
DataParallelStrategy
DeepSpeedStrategy
HorovodStrategy
+ HPUParallelStrategy
IPUStrategy
ParallelStrategy
SingleDeviceStrategy
+ SingleHPUStrategy
SingleTPUStrategy
Strategy
TPUSpawnStrategy
@@ -198,6 +201,7 @@ Precision Plugins
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
+ HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
@@ -234,6 +238,7 @@ Checkpoint IO Plugins
:template: classtemplate.rst
CheckpointIO
+ HPUCheckpointIO
TorchCheckpointIO
XLACheckpointIO
diff --git a/docs/source/extensions/accelerator.rst b/docs/source/extensions/accelerator.rst
index 762f2b6e57a90..0d78371a0e7ed 100644
--- a/docs/source/extensions/accelerator.rst
+++ b/docs/source/extensions/accelerator.rst
@@ -15,6 +15,7 @@ Currently there are accelerators for:
- GPU
- TPU
- IPU
+- HPU
Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.
@@ -58,5 +59,6 @@ Accelerator API
Accelerator
CPUAccelerator
GPUAccelerator
- TPUAccelerator
+ HPUAccelerator
IPUAccelerator
+ TPUAccelerator
diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst
index 252fb47570fc4..3bfa7ad24b29c 100644
--- a/docs/source/extensions/plugins.rst
+++ b/docs/source/extensions/plugins.rst
@@ -61,17 +61,18 @@ Precision Plugins
:nosignatures:
:template: classtemplate.rst
- PrecisionPlugin
- MixedPrecisionPlugin
- NativeMixedPrecisionPlugin
- ShardedNativeMixedPrecisionPlugin
ApexMixedPrecisionPlugin
DeepSpeedPrecisionPlugin
- TPUPrecisionPlugin
- TPUBf16PrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
+ HPUPrecisionPlugin
IPUPrecisionPlugin
+ MixedPrecisionPlugin
+ NativeMixedPrecisionPlugin
+ PrecisionPlugin
+ ShardedNativeMixedPrecisionPlugin
+ TPUBf16PrecisionPlugin
+ TPUPrecisionPlugin
Cluster Environments
diff --git a/docs/source/extensions/strategy.rst b/docs/source/extensions/strategy.rst
index e85b719e8566c..7c5596c7362ea 100644
--- a/docs/source/extensions/strategy.rst
+++ b/docs/source/extensions/strategy.rst
@@ -108,9 +108,11 @@ Built-In Training Strategies
DataParallelStrategy
DeepSpeedStrategy
HorovodStrategy
+ HPUParallelStrategy
IPUStrategy
ParallelStrategy
SingleDeviceStrategy
+ SingleHPUStrategy
SingleTPUStrategy
Strategy
TPUSpawnStrategy
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4e7543b51f4a7..d02eff4fb225a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -88,6 +88,7 @@ Welcome to PyTorch Lightning
accelerators/gpu
accelerators/tpu
accelerators/ipu
+ accelerators/hpu
.. toctree::
:maxdepth: 1
diff --git a/pl_examples/hpu_examples/simple_mnist/mnist.py b/pl_examples/hpu_examples/simple_mnist/mnist.py
new file mode 100644
index 0000000000000..a5d4b47d6b829
--- /dev/null
+++ b/pl_examples/hpu_examples/simple_mnist/mnist.py
@@ -0,0 +1,75 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from jsonargparse import lazy_instance
+from torch.nn import functional as F
+
+import pytorch_lightning as pl
+from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule
+from pytorch_lightning.plugins import HPUPrecisionPlugin
+from pytorch_lightning.utilities.cli import LightningCLI
+
+
+class LitClassifier(pl.LightningModule):
+ def __init__(self):
+ super().__init__()
+ self.l1 = torch.nn.Linear(28 * 28, 10)
+
+ def forward(self, x):
+ return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+ def training_step(self, batch, batch_idx):
+ x, y = batch
+ loss = F.cross_entropy(self(x), y)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ x, y = batch
+ probs = self(x)
+ acc = self.accuracy(probs, y)
+ self.log("val_acc", acc)
+
+ def test_step(self, batch, batch_idx):
+ x, y = batch
+ logits = self(x)
+ acc = self.accuracy(logits, y)
+ self.log("test_acc", acc)
+
+ @staticmethod
+ def accuracy(logits, y):
+ acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
+ return acc
+
+ def configure_optimizers(self):
+ return torch.optim.Adam(self.parameters(), lr=0.02)
+
+
+if __name__ == "__main__":
+ cli = LightningCLI(
+ LitClassifier,
+ MNISTDataModule,
+ trainer_defaults={
+ "accelerator": "hpu",
+ "devices": 1,
+ "max_epochs": 1,
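+            # `lazy_instance` defers construction of the plugin so that its init arguments
+            # can still be overridden from the command line by the CLI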
+ "plugins": lazy_instance(HPUPrecisionPlugin, precision=16),
+ },
+ run=False,
+ save_config_overwrite=True,
+ )
+
+ # Run the model ⚡
+ cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+ cli.trainer.validate(cli.model, datamodule=cli.datamodule)
+ cli.trainer.test(cli.model, datamodule=cli.datamodule)
diff --git a/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt
new file mode 100644
index 0000000000000..53ec99c15b4ce
--- /dev/null
+++ b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt
@@ -0,0 +1,2 @@
+linear
+relu
diff --git a/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt
new file mode 100644
index 0000000000000..4509b7e58ac29
--- /dev/null
+++ b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt
@@ -0,0 +1 @@
+cross_entropy
diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index fe9ae0a120cfb..1ab90e025b087 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -13,6 +13,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
+from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 526cec3e47319..5fe6b53dd54b5 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -28,6 +28,7 @@ class Accelerator(ABC):
- GPU
- TPU
- IPU
+ - HPU
"""
def setup_environment(self, root_device: torch.device) -> None:
diff --git a/pytorch_lightning/accelerators/hpu.py b/pytorch_lightning/accelerators/hpu.py
new file mode 100644
index 0000000000000..76fdb02b307b8
--- /dev/null
+++ b/pytorch_lightning/accelerators/hpu.py
@@ -0,0 +1,69 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug
+
+
+class HPUAccelerator(Accelerator):
+ """Accelerator for HPU devices."""
+
+ def setup_environment(self, root_device: torch.device) -> None:
+ """
+ Raises:
+ MisconfigurationException:
+ If the selected device is not HPU.
+ """
+ super().setup_environment(root_device)
+ if root_device.type != "hpu":
+ raise MisconfigurationException(f"Device should be HPU, got {root_device} instead.")
+
+ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
+ """HPU device stats aren't supported yet."""
+ rank_zero_debug("HPU device stats aren't supported yet.")
+ return {}
+
+ @staticmethod
+ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]:
+ """Accelerator device parsing logic."""
+ return device_parser.parse_hpus(devices)
+
+ @staticmethod
+ def get_parallel_devices(devices: int) -> List[torch.device]:
+ """Gets parallel devices for the Accelerator."""
+ return [torch.device("hpu")] * devices
+
+ @staticmethod
+ def auto_device_count() -> int:
+ """Get the devices when set to auto."""
+ # TODO(@kaushikb11): Update this when api is exposed by the Habana team
+ return 8
+
+ @staticmethod
+ def is_available() -> bool:
+ return _HPU_AVAILABLE
+
+ @classmethod
+ def register_accelerators(cls, accelerator_registry: Dict) -> None:
+ accelerator_registry.register(
+ "hpu",
+ cls,
+            description=f"{cls.__name__}",
+ )
diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py
new file mode 100644
index 0000000000000..9c70a2867b429
--- /dev/null
+++ b/pytorch_lightning/overrides/torch_distributed.py
@@ -0,0 +1,168 @@
+# type: ignore
+
+import io
+import logging
+import os
+import pickle
+
+import torch
+from torch._C._distributed_c10d import ProcessGroup
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+logger = logging.getLogger(__name__)
+
+if torch.distributed.is_available():
+ from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember
+
+# The code underneath is taken from PyTorch `torch/distributed/distributed_c10d.py`
+# the distributed backend and tensor type updates for habana backend is done here before broadcast
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L267 # noqa: E501
+def _rank_not_in_group(group: ProcessGroup):
+ """Helper that checks if the current process's rank is not in a given group."""
+ if group is None:
+ return False
+ return group == GroupMember.NON_GROUP_MEMBER
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L1551 # noqa: E501
+def _object_to_tensor(obj):
+ f = io.BytesIO()
+ _pickler(f).dump(obj)
+ byte_storage = torch.ByteStorage.from_buffer(f.getvalue())
+ # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause 100X slowdown.
+ # See: https://github.com/pytorch/pytorch/issues/65696
+ byte_tensor = torch.ByteTensor(byte_storage)
+ local_size = torch.LongTensor([byte_tensor.numel()])
+ return byte_tensor, local_size
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L1563 # noqa: E501
+def _tensor_to_object(tensor, tensor_size):
+ buf = tensor.numpy().tobytes()[:tensor_size]
+ return _unpickler(io.BytesIO(buf)).load()
+
+
+def _broadcast_object_list(object_list, src=0, group=None, device=None):
+ """Broadcasts picklable objects in ``object_list`` to the whole group. Similar to :func:`broadcast`, but Python
+ objects can be passed in. Note that all objects in ``object_list`` must be picklable in order to be
+ broadcasted.
+
+ Args:
+ object_list: List of input objects to broadcast.
+ Each object must be picklable. Only objects on the ``src`` rank will
+ be broadcast, but each rank must provide lists of equal sizes.
+ src: Source rank from which to broadcast ``object_list``.
+ group: The process group to work on. If None,
+ the default process group will be used. Default is ``None``.
+ device: If not None, the objects are
+ serialized and converted to tensors which are moved to the
+ ``device`` before broadcasting. Default is ``None``.
+
+ Returns:
+ ``None``. If rank is part of the group, ``object_list`` will contain the
+ broadcasted objects from ``src`` rank.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+ of objects must be moved to the GPU device before communication takes
+ place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+ ensure that this is set so that each rank has an individual GPU, via
+ ``torch.cuda.set_device()``.
+
+ .. note:: Note that this API differs slightly from the :func:`all_gather`
+ collective since it does not provide an ``async_op`` handle and thus
+ will be a blocking call.
+
+ .. warning::
+ :func:`broadcast_object_list` uses ``pickle`` module implicitly, which
+ is known to be insecure. It is possible to construct malicious pickle
+ data which will execute arbitrary code during unpickling. Only call this
+ function with data you trust.
+ """
+ if _rank_not_in_group(group):
+ return
+
+ my_rank = get_rank()
+ # Serialize object_list elements to tensors on src rank.
+ if my_rank == src:
+ tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
+ object_sizes_tensor = torch.cat(size_list)
+ else:
+ object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+ # Current device selection.
+    # To preserve backwards compatibility, ``device`` defaults to ``None``
+ # in which case we run current logic of device selection, i.e.
+ # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
+ # case it is not ``None`` we move the size and object tensors to be
+ # broadcasted to this device.
+ group_backend = get_backend(group)
+ is_nccl_backend = group_backend == Backend.NCCL
+ is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
+ current_device = None
+ if device is not None:
+ if is_nccl_backend and device.type != "cuda":
+ raise ValueError("device type must be cuda for nccl backend")
+ current_device = device
+ else:
+ current_device = torch.device("cpu")
+ if is_nccl_backend:
+ # See note about using torch.cuda.current_device() here in
+ # docstring. We cannot simply use my_rank since rank == device is
+ # not necessarily true.
+ current_device = torch.device("cuda", torch.cuda.current_device())
+ if is_nccl_backend:
+ object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+ elif is_hpu_backend:
+ current_device = torch.device("hpu")
+        # Workaround: HPU does not support long tensors for collectives
+ if (object_sizes_tensor.type() == "torch.LongTensor") or (object_sizes_tensor.type() == "torch.hpu.LongTensor"):
+ object_sizes_tensor = object_sizes_tensor.int()
+ else:
+ print("unhandled hpu object_sizes_tensor type :: ", object_sizes_tensor.type())
+ object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+ # Broadcast object sizes
+ broadcast(object_sizes_tensor, src=src, group=group)
+
+ # Concatenate and broadcast serialized object tensors
+ if my_rank == src:
+ object_tensor = torch.cat(tensor_list)
+ else:
+ object_tensor = torch.empty(
+ torch.sum(object_sizes_tensor).int().item(),
+ dtype=torch.uint8,
+ )
+
+ if is_nccl_backend or is_hpu_backend:
+ object_tensor = object_tensor.to(current_device)
+
+ broadcast(object_tensor, src=src, group=group)
+ # Deserialize objects using their stored sizes.
+ offset = 0
+ if my_rank != src:
+ for i, obj_size in enumerate(object_sizes_tensor):
+ obj_view = object_tensor[offset : offset + obj_size]
+ obj_view = obj_view.type(torch.uint8)
+ if obj_view.device != torch.device("cpu"):
+ obj_view = obj_view.cpu()
+ offset += obj_size
+ object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+if not torch.distributed.is_available():
+ # avoid failures on early PyTorch versions for Windows where
+ # not all functions used in `broadcast_object_list` are available.
+ def _broadcast_noop(obj, *_, **__):
+ return obj
+
+ broadcast_object_list = _broadcast_noop
+else:
+ broadcast_object_list = _broadcast_object_list
diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index 2e4e0da6b2755..0f1c4ca85ed5a 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -2,6 +2,7 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
@@ -9,6 +10,7 @@
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
@@ -39,6 +41,7 @@
"CheckpointIO",
"TorchCheckpointIO",
"XLACheckpointIO",
+ "HPUCheckpointIO",
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
@@ -51,6 +54,7 @@
"HorovodPlugin",
"IPUPlugin",
"IPUPrecisionPlugin",
+ "HPUPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
diff --git a/pytorch_lightning/plugins/io/__init__.py b/pytorch_lightning/plugins/io/__init__.py
index 1b14eee6ec4f2..abd196eb2b1e3 100644
--- a/pytorch_lightning/plugins/io/__init__.py
+++ b/pytorch_lightning/plugins/io/__init__.py
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
+from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401
diff --git a/pytorch_lightning/plugins/io/hpu_plugin.py b/pytorch_lightning/plugins/io/hpu_plugin.py
new file mode 100644
index 0000000000000..c72d1d9fcd112
--- /dev/null
+++ b/pytorch_lightning/plugins/io/hpu_plugin.py
@@ -0,0 +1,52 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Any, Dict, Optional
+
+import torch
+
+from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
+from pytorch_lightning.utilities.types import _PATH
+
+
+class HPUCheckpointIO(TorchCheckpointIO):
+ """CheckpointIO to save checkpoints for HPU training strategies."""
+
+ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
+ """Save model/training states as a checkpoint file through state-dump and file-write.
+
+ Args:
+ checkpoint: dict containing model and trainer state
+ path: write-target path
+            storage_options: not used in ``HPUCheckpointIO.save_checkpoint``
+
+ Raises:
+ TypeError:
+ If ``storage_options`` arg is passed in
+ """
+ if storage_options is not None:
+ raise TypeError(
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
+ f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
+ " to define how you'd like to use `storage_options`."
+ )
+ fs = get_filesystem(path)
+ fs.makedirs(os.path.dirname(path), exist_ok=True)
+
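+        # tensors may live on the HPU device; move them to CPU so the checkpoint can be serialized safely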
+ checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
+ # write the checkpoint dictionary to the provided path
+ atomic_save(checkpoint, path)
diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py
index b407e47ca9337..4bc29c1be1864 100644
--- a/pytorch_lightning/plugins/precision/__init__.py
+++ b/pytorch_lightning/plugins/precision/__init__.py
@@ -1,9 +1,23 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
+from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
diff --git a/pytorch_lightning/plugins/precision/hpu.py b/pytorch_lightning/plugins/precision/hpu.py
new file mode 100644
index 0000000000000..3c02d82a2de10
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/hpu.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _HPU_AVAILABLE
+
+if _HPU_AVAILABLE:
+ from habana_frameworks.torch.hpex import hmp
+
+
+class HPUPrecisionPlugin(PrecisionPlugin):
+ """Plugin that enables bfloat/half support on HPUs.
+
+ Args:
+ precision: The precision to use.
+ opt_level: Choose optimization level for hmp.
+ bf16_file_path: Path to bf16 ops list in hmp O1 mode.
+ fp32_file_path: Path to fp32 ops list in hmp O1 mode.
+ verbose: Enable verbose mode for hmp.
+ """
+
+ def __init__(
+ self,
+ precision: Union[str, int],
+ opt_level: str = "O2",
+ bf16_file_path: Optional[str] = None,
+ fp32_file_path: Optional[str] = None,
+ verbose: bool = False,
+ ) -> None:
+ if not _HPU_AVAILABLE:
+ raise MisconfigurationException("HPU precision plugin requires HPU devices.")
+ supported_precision_values = (16, 32, "bf16")
+ if precision not in supported_precision_values:
+ raise ValueError(
+ f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported."
+ f" `precision` must be one of: {supported_precision_values}."
+ )
+ super().__init__()
+ self.precision = precision
+ hmp.convert(
+ opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose
+ )
diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py
index a4cd57a50ac1d..38a2b466e57e9 100644
--- a/pytorch_lightning/strategies/__init__.py
+++ b/pytorch_lightning/strategies/__init__.py
@@ -19,11 +19,13 @@
from pytorch_lightning.strategies.dp import DataParallelStrategy # noqa: F401
from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy # noqa: F401
from pytorch_lightning.strategies.horovod import HorovodStrategy # noqa: F401
+from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy # noqa: F401
from pytorch_lightning.strategies.ipu import IPUStrategy # noqa: F401
from pytorch_lightning.strategies.parallel import ParallelStrategy # noqa: F401
from pytorch_lightning.strategies.sharded import DDPShardedStrategy # noqa: F401
from pytorch_lightning.strategies.sharded_spawn import DDPSpawnShardedStrategy # noqa: F401
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy # noqa: F401
+from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy # noqa: F401
from pytorch_lightning.strategies.single_tpu import SingleTPUStrategy # noqa: F401
from pytorch_lightning.strategies.strategy import Strategy # noqa: F401
from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry # noqa: F401
diff --git a/pytorch_lightning/strategies/hpu_parallel.py b/pytorch_lightning/strategies/hpu_parallel.py
new file mode 100644
index 0000000000000..562a841b89510
--- /dev/null
+++ b/pytorch_lightning/strategies/hpu_parallel.py
@@ -0,0 +1,132 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from typing import Dict, List, Optional
+
+import torch
+import torch.distributed
+
+import pytorch_lightning as pl
+from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.strategies.ddp import DDPStrategy
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2
+
+if _HPU_AVAILABLE:
+ import habana_frameworks.torch.core.hccl # noqa: F401
+ from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+log = logging.getLogger(__name__)
+
+
+class HPUParallelStrategy(DDPStrategy):
+ """Strategy for distributed training on multiple HPU devices."""
+
+ strategy_name = "hpu_parallel"
+
+ def __init__(
+ self,
+ accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
+ parallel_devices: Optional[List[torch.device]] = None,
+ checkpoint_io: Optional[CheckpointIO] = None,
+ precision_plugin: Optional[PrecisionPlugin] = None,
+ process_group_backend: Optional[str] = "hccl",
+ ) -> None:
+
+ if not _HPU_AVAILABLE:
+ raise MisconfigurationException("`HPUParallelStrategy` requires HPU devices to run")
+
+ super().__init__(
+ accelerator=accelerator,
+ parallel_devices=parallel_devices,
+ checkpoint_io=checkpoint_io or HPUCheckpointIO(),
+ precision_plugin=precision_plugin,
+ process_group_backend=process_group_backend,
+ )
+
+ def setup_environment(self) -> None:
+ # This function is used to load Habana libraries required for PyTorch
+ # to register HPU as one of the available devices.
+ load_habana_module()
+
+ os.environ["ID"] = str(self.local_rank)
+ if self._process_group_backend == "hccl":
+ # this env is used in overrides to check the backend initiated
+ os.environ["HCCL_DISTRIBUTED_BACKEND"] = str(1)
+ super().setup_environment()
+
+ def determine_ddp_device_ids(self) -> None:
+ return None
+
+ def pre_configure_ddp(self): # type: ignore
+        # if unset, default `find_unused_parameters` to `True`
+        # Many models require setting this parameter to True, as there are corner cases
+        # when not all parameter backward hooks are fired by the autograd engine even if requires_grad is set to True.
+        # This flag does come with a performance hit, so it is suggested to disable it where possible.
+ self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
+
+ self._static_graph = False
+ static_graph = self._ddp_kwargs.get("static_graph")
+ if static_graph:
+            # when _set_static_graph() is called, find_unused_parameters does not have any significance.
+            # Resetting the value of find_unused_parameters to False, which is the default value for DDP
+ self._ddp_kwargs["find_unused_parameters"] = False
+ self._static_graph = True
+ if static_graph is not None:
+ # DDP does not accept static_graph as a parameter, hence removing it from the list
+ del self._ddp_kwargs["static_graph"]
+
+ def configure_ddp(self) -> None:
+ # DDP does not accept static graph as param with torch < 1.11
+ if _TORCH_LESSER_EQUAL_1_10_2:
+ log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
+ self.pre_configure_ddp()
+ self.model = self._setup_model(LightningDistributedModule(self.model)) # type: ignore
+ if self.root_device.type == "hpu" and self._static_graph:
+ self._model._set_static_graph() # type: ignore
+ self._register_ddp_hooks()
+        else:
+            super().configure_ddp()
+
+ def broadcast(self, obj: object, src: int = 0) -> object: # type: ignore
+ obj = [obj]
+ if self.global_rank != src:
+ obj = [None]
+
+ broadcast_object_list(obj, src, group=_group.WORLD)
+ return obj[0]
+
+ def teardown(self) -> None:
+ log.detail(f"{self.__class__.__name__}: tearing down strategy.")
+ super().teardown()
+
+ log.detail(f"{self.__class__.__name__}: moving model to CPU")
+ self.lightning_module.cpu() # type: ignore
+ # Was set to local rank
+ os.environ.pop("ID", None)
+ os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None)
+
+ @classmethod
+ def register_strategies(cls, strategy_registry: Dict) -> None:
+ strategy_registry.register(
+ cls.strategy_name,
+ cls,
+            description=f"{cls.__name__}",
+ )
diff --git a/pytorch_lightning/strategies/single_hpu.py b/pytorch_lightning/strategies/single_hpu.py
new file mode 100644
index 0000000000000..edafe63441906
--- /dev/null
+++ b/pytorch_lightning/strategies/single_hpu.py
@@ -0,0 +1,80 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional
+
+import pytorch_lightning as pl
+from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
+from pytorch_lightning.utilities import _HPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _DEVICE
+
+if _HPU_AVAILABLE:
+ import habana_frameworks.torch.core.hccl # noqa: F401
+ from habana_frameworks.torch.utils.library_loader import load_habana_module
+
+
+class SingleHPUStrategy(SingleDeviceStrategy):
+    """Strategy for training on a single HPU device."""
+
+ strategy_name = "hpu_single"
+
+ def __init__(
+ self,
+ device: _DEVICE = "hpu",
+ accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
+ checkpoint_io: Optional[HPUCheckpointIO] = None,
+ precision_plugin: Optional[PrecisionPlugin] = None,
+ ):
+
+ if not _HPU_AVAILABLE:
+ raise MisconfigurationException("`SingleHPUStrategy` requires HPU devices to run")
+
+ # This function is used to load Habana libraries required for PyTorch
+ # to register HPU as one of the available devices.
+ load_habana_module()
+
+ super().__init__(
+ accelerator=accelerator,
+ device=device,
+ checkpoint_io=checkpoint_io or HPUCheckpointIO(),
+ precision_plugin=precision_plugin,
+ )
+
+ @property
+ def is_distributed(self) -> bool:
+ return False
+
+ def setup(self, trainer: "pl.Trainer") -> None:
+ self.model_to_device()
+ super().setup(trainer)
+
+ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+ super().setup_optimizers(trainer)
+
+ if len(self.optimizers) > 1:
+ raise MisconfigurationException("HPUs currently support only one optimizer.")
+
+ def model_to_device(self) -> None:
+ self.model.to(self.root_device) # type: ignore
+
+ @classmethod
+ def register_strategies(cls, strategy_registry: Dict) -> None:
+ strategy_registry.register(
+ cls.strategy_name,
+ cls,
+            description=f"{cls.__name__}",
+ )
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 43f87ec731acf..28c6da86c6d59 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -22,6 +22,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
+from pytorch_lightning.accelerators.hpu import HPUAccelerator
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.accelerators.registry import AcceleratorRegistry
from pytorch_lightning.accelerators.tpu import TPUAccelerator
@@ -31,6 +32,7 @@
DeepSpeedPrecisionPlugin,
DoublePrecisionPlugin,
FullyShardedNativeMixedPrecisionPlugin,
+ HPUPrecisionPlugin,
IPUPrecisionPlugin,
NativeMixedPrecisionPlugin,
PLUGIN_INPUT,
@@ -58,8 +60,10 @@
DDPStrategy,
DeepSpeedStrategy,
HorovodStrategy,
+ HPUParallelStrategy,
IPUStrategy,
SingleDeviceStrategy,
+ SingleHPUStrategy,
SingleTPUStrategy,
Strategy,
StrategyRegistry,
@@ -77,6 +81,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
_HOROVOD_AVAILABLE,
+ _HPU_AVAILABLE,
_IPU_AVAILABLE,
_TORCH_GREATER_EQUAL_1_8,
_TPU_AVAILABLE,
@@ -187,7 +192,6 @@ def __init__(
self._check_device_config_and_set_final_flags(
devices=devices, num_nodes=num_nodes, num_processes=num_processes, gpus=gpus, ipus=ipus, tpu_cores=tpu_cores
)
-
# 2. Instantiate Accelerator
# handle `auto` and `None`
self._set_accelerator_if_ipu_strategy_is_passed()
@@ -427,7 +431,7 @@ def _check_device_config_and_set_final_flags(
if self._devices_flag == "auto" and self._accelerator_flag is None:
raise MisconfigurationException(
f"You passed `devices={devices}` but haven't specified"
- " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping"
+                " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu')` for the devices mapping"
)
def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
@@ -481,6 +485,8 @@ def _choose_accelerator(self) -> str:
return "tpu"
if _IPU_AVAILABLE:
return "ipu"
+ if _HPU_AVAILABLE:
+ return "hpu"
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
return "gpu"
return "cpu"
@@ -545,6 +551,11 @@ def _is_slurm_managing_tasks(self) -> bool:
def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "ipu":
return IPUStrategy.strategy_name
+ if self._accelerator_flag == "hpu":
+ if self._parallel_devices and len(self._parallel_devices) > 1:
+ return HPUParallelStrategy.strategy_name
+ else:
+ return SingleHPUStrategy(device=torch.device("hpu"))
if self._accelerator_flag == "tpu":
if self._parallel_devices and len(self._parallel_devices) > 1:
return TPUSpawnStrategy.strategy_name
@@ -642,6 +653,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.accelerator, IPUAccelerator):
return IPUPrecisionPlugin(self._precision_flag) # type: ignore
+ if isinstance(self.accelerator, HPUAccelerator):
+ return HPUPrecisionPlugin(self._precision_flag) # type: ignore
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_flag == 32:
return TPUPrecisionPlugin()
@@ -707,6 +720,11 @@ def _validate_precision_choice(self) -> None:
f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
f" found: {self._precision_plugin_flag}."
)
+ if isinstance(self.accelerator, HPUAccelerator):
+ if self._precision_flag not in (16, "bf16", 32):
+ raise MisconfigurationException(
+ f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported."
+ )
if (
self._precision_flag == 16
and isinstance(self.accelerator, CPUAccelerator)
@@ -768,7 +786,15 @@ def _lazy_init_strategy(self) -> None:
):
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`,"
- f" found {self.strategy}."
+ f" found {self.strategy.__class__.__name__}."
+ )
+
+ if isinstance(self.accelerator, HPUAccelerator) and not isinstance(
+ self.strategy, (SingleHPUStrategy, HPUParallelStrategy)
+ ):
+ raise ValueError(
+ "The `HPUAccelerator` can only be used with a `SingleHPUStrategy` or `HPUParallelStrategy`,"
+ f" found {self.strategy.__class__.__name__}."
)
"""The following properties are here for backward-compatibility and will be deprecated and removed in favor
@@ -801,6 +827,7 @@ def is_distributed(self) -> bool:
DeepSpeedStrategy,
TPUSpawnStrategy,
HorovodStrategy,
+ HPUParallelStrategy,
)
is_distributed = isinstance(self.strategy, distributed_strategy)
if isinstance(self.accelerator, TPUAccelerator):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fab1a5e9b9d83..bcda0764ff11f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -31,7 +31,7 @@
from torch.utils.data import DataLoader
import pytorch_lightning as pl
-from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
@@ -77,6 +77,7 @@
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import (
+ _HPU_AVAILABLE,
_IPU_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
@@ -195,7 +196,7 @@ def __init__(
Args:
- accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto")
+ accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "auto")
as well as custom accelerator instances.
.. deprecated:: v1.5
@@ -353,7 +354,7 @@ def __init__(
Default: ``None``.
precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
- Can be used on CPU, GPU, TPUs or IPUs.
+ Can be used on CPU, GPU, TPUs, HPUs or IPUs.
Default: ``32``.
max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
@@ -1818,6 +1819,9 @@ def _log_device_info(self) -> None:
num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")
+ num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0
+ rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")
+
if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
rank_zero_warn(
"GPU available but not used. Set `accelerator` and `devices` using"
@@ -1837,6 +1841,12 @@ def _log_device_info(self) -> None:
f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`."
)
+ if _HPU_AVAILABLE and not isinstance(self.accelerator, HPUAccelerator):
+ rank_zero_warn(
+ "HPU available but not used. Set `accelerator` and `devices` using"
+ f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
+ )
+
"""
Data loading methods
"""
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 916532d964952..930467e50108d 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -36,6 +36,7 @@
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
_GROUP_AVAILABLE,
_HOROVOD_AVAILABLE,
+ _HPU_AVAILABLE,
_HYDRA_AVAILABLE,
_HYDRA_EXPERIMENTAL_AVAILABLE,
_IPU_AVAILABLE,
diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py
index d7b8a319ea4d2..5f886a4e36598 100644
--- a/pytorch_lightning/utilities/device_parser.py
+++ b/pytorch_lightning/utilities/device_parser.py
@@ -243,3 +243,24 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
if tpu_cores in ("1", "8"):
return int(tpu_cores)
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]
+
+
+def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]:
+ """
+ Parses the hpus given in the format as accepted by the
+ :class:`~pytorch_lightning.trainer.Trainer` for the `devices` flag.
+
+ Args:
+ devices: An integer that indicates the number of Gaudi devices to be used
+
+ Returns:
+ Either an integer or ``None`` if no devices were requested
+
+ Raises:
+ MisconfigurationException:
+ If devices aren't of type `int` or `str`
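+
+    Example::
+
+        parse_hpus(8)      # -> 8
+        parse_hpus("8")    # -> 8 (strings are converted to int)
+        parse_hpus(None)   # -> None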
+ """
+ if devices is not None and not isinstance(devices, (int, str)):
+ raise MisconfigurationException("`devices` for `HPUAccelerator` must be int, string or None.")
+
+ return int(devices) if isinstance(devices, str) else devices
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index d5dca963f12c4..b94351b22a335 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -21,7 +21,12 @@
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
+from pytorch_lightning.utilities.imports import (
+ _HPU_AVAILABLE,
+ _TORCH_GREATER_EQUAL_1_8,
+ _TORCH_GREATER_EQUAL_1_9,
+ _TPU_AVAILABLE,
+)
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
@@ -124,6 +129,14 @@ def sync_ddp(
else:
op = reduce_op
+    # Workaround for HPU: HPU doesn't support Long types, forcefully cast them to float
+ if _HPU_AVAILABLE:
+ is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
+ if is_hpu_backend:
+ if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"):
+ new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
+ result = result.float()
+
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
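
Editorial note: for readers unfamiliar with the HPU backend, the workaround above boils down to the cast sketched here (the helper name is hypothetical; the backend check and logging are omitted):

    import torch

    def _cast_long_for_hpu(result: torch.Tensor) -> torch.Tensor:
        # Long (int64) tensors cannot be reduced on HPU, so reduce them as float32 instead
        if result.type() in ("torch.LongTensor", "torch.hpu.LongTensor"):
            return result.float()
        return result

    assert _cast_long_for_hpu(torch.tensor([1, 2, 3])).dtype == torch.float32
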
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 105b167a29910..f9ae0c82d0444 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -123,6 +123,7 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
DDP_SHARDED = "ddp_sharded"
DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
DDP_FULLY_SHARDED = "ddp_fully_sharded"
+ HPU_PARALLEL = "hpu_parallel"
@staticmethod
def interactive_compatible_types() -> list[DistributedType]:
@@ -248,6 +249,7 @@ class _StrategyType(LightningEnum):
DDP_SHARDED_SPAWN = "ddp_sharded_spawn"
DDP_FULLY_SHARDED = "ddp_fully_sharded"
BAGUA = "bagua"
+ HPU_PARALLEL = "hpu_parallel"
@staticmethod
def interactive_compatible_types() -> list[_StrategyType]:
@@ -279,6 +281,7 @@ class _AcceleratorType(LightningEnum):
GPU = "GPU"
IPU = "IPU"
TPU = "TPU"
+ HPU = "HPU"
class _FaultTolerantMode(LightningEnum):
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 8795fa66d5fd3..c98874bcbd742 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -94,6 +94,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
+_TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2")
_TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0")
_APEX_AVAILABLE = _module_available("apex.amp")
@@ -112,6 +113,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
_OMEGACONF_AVAILABLE = _package_available("omegaconf")
_POPTORCH_AVAILABLE = _package_available("poptorch")
+_HABANA_FRAMEWORK_AVAILABLE = _package_available("habana_frameworks")
_RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCHTEXT_AVAILABLE = _package_available("torchtext")
@@ -134,6 +136,13 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
else:
_IPU_AVAILABLE = False
+if _HABANA_FRAMEWORK_AVAILABLE:
+ from habana_frameworks.torch.utils.library_loader import is_habana_avaialble
+
+ _HPU_AVAILABLE = is_habana_avaialble()
+else:
+ _HPU_AVAILABLE = False
+
# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
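
Editorial note: `_HPU_AVAILABLE` follows the same guarded-import pattern as `_IPU_AVAILABLE` above: the flag is only probed when `habana_frameworks` is importable. A minimal sketch of how downstream code can branch on it (the helper below is hypothetical, not part of the patch):

    from pytorch_lightning.utilities.imports import _HPU_AVAILABLE

    def default_hpu_device_count() -> int:
        # hypothetical helper: all eight Gaudi cards when available, otherwise none
        return 8 if _HPU_AVAILABLE else 0
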
diff --git a/requirements/test.txt b/requirements/test.txt
index c7d2860ebee61..51d9ecf71db44 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -6,6 +6,9 @@ twine==3.2
mypy>=0.920
flake8>=3.9.2
pre-commit>=1.0
+pytest-forked
+sklearn
+jsonargparse
# needed in tests
cloudpickle>=1.3
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 794cb9b2922cd..8c29dfb09a0d0 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -925,7 +925,10 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
@mock.patch("torch.cuda.is_available", return_value=False)
@mock.patch("pytorch_lightning.utilities.imports._TPU_AVAILABLE", return_value=False)
@mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False)
-def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock):
+@mock.patch("pytorch_lightning.utilities.imports._HPU_AVAILABLE", return_value=False)
+def test_devices_auto_choice_cpu(
+ is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock, is_hpu_available_mock
+):
trainer = Trainer(accelerator="auto", devices="auto")
assert trainer.num_devices == 1
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index b783f63b1e64a..4e2b521873408 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -63,4 +63,4 @@ def is_available():
def test_available_accelerators_in_registry():
- assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "ipu", "tpu"]
+ assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "tpu"]
diff --git a/tests/accelerators/test_hpu.py b/tests/accelerators/test_hpu.py
new file mode 100644
index 0000000000000..64415eeb5a0c8
--- /dev/null
+++ b/tests/accelerators/test_hpu.py
@@ -0,0 +1,242 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+
+from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning.accelerators import HPUAccelerator
+from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy
+from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy
+from pytorch_lightning.utilities import _HPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.datamodules import ClassifDataModule
+from tests.helpers.runif import RunIf
+from tests.helpers.simple_models import ClassificationModel
+
+
+@RunIf(hpu=True)
+def test_availability():
+ assert HPUAccelerator.is_available()
+
+
+@pytest.mark.skipif(_HPU_AVAILABLE, reason="test requires non-HPU machine")
+def test_fail_if_no_hpus():
+ with pytest.raises(MisconfigurationException, match="HPUAccelerator can not run on your system"):
+ Trainer(accelerator="hpu", devices=1)
+
+
+@RunIf(hpu=True)
+def test_accelerator_selected():
+ trainer = Trainer(accelerator="hpu")
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_all_stages(tmpdir, hpus):
+ """Tests all the model stages using BoringModel on HPU."""
+ model = BoringModel()
+
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ fast_dev_run=True,
+ accelerator="hpu",
+ devices=hpus,
+ precision=16,
+ )
+ trainer.fit(model)
+ trainer.validate(model)
+ trainer.test(model)
+ trainer.predict(model)
+
+
+@RunIf(hpu=True)
+def test_optimization(tmpdir):
+ seed_everything(42)
+
+ dm = ClassifDataModule(length=1024)
+ model = ClassificationModel()
+
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="hpu", devices=1)
+
+ # fit model
+ trainer.fit(model, dm)
+ assert trainer.state.finished, f"Training failed with {trainer.state}"
+ assert dm.trainer is not None
+
+ # validate
+ result = trainer.validate(datamodule=dm)
+ assert dm.trainer is not None
+ assert result[0]["val_acc"] > 0.7
+
+ # test
+ result = trainer.test(model, datamodule=dm)
+ assert dm.trainer is not None
+ test_result = result[0]["test_acc"]
+ assert test_result > 0.6
+
+ # test saved model
+ model_path = os.path.join(tmpdir, "model.pt")
+ trainer.save_checkpoint(model_path)
+
+ model = ClassificationModel.load_from_checkpoint(model_path)
+
+ trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)
+
+ result = trainer.test(model, datamodule=dm)
+ saved_result = result[0]["test_acc"]
+ assert saved_result == test_result
+
+
+@RunIf(hpu=True)
+def test_stages_correct(tmpdir):
+ """Ensure all stages correctly are traced correctly by asserting the output for each stage."""
+
+ class StageModel(BoringModel):
+ def training_step(self, batch, batch_idx):
+ loss = super().training_step(batch, batch_idx)
+ loss = loss.get("loss")
+ # tracing requires a loss value that depends on the model,
+ # so force a fixed value while keeping the computed loss in the graph.
+ loss = (loss - loss) + torch.tensor(1)
+ return {"loss": loss}
+
+ def validation_step(self, batch, batch_idx):
+ loss = super().validation_step(batch, batch_idx)
+ x = loss.get("x")
+ x = (x - x) + torch.tensor(2)
+ return {"x": x}
+
+ def test_step(self, batch, batch_idx):
+ loss = super().test_step(batch, batch_idx)
+ y = loss.get("y")
+ y = (y - y) + torch.tensor(3)
+ return {"y": y}
+
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
+ output = super().predict_step(batch, batch_idx)
+ return (output - output) + torch.tensor(4)
+
+ class TestCallback(Callback):
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
+ assert outputs["loss"].item() == 1
+
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+ assert outputs["x"].item() == 2
+
+ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+ assert outputs["y"].item() == 3
+
+ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+ assert torch.all(outputs == 4).item()
+
+ model = StageModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir, fast_dev_run=True, accelerator="hpu", devices=1, callbacks=TestCallback()
+ )
+ trainer.fit(model)
+ trainer.test(model)
+ trainer.validate(model)
+ trainer.predict(model)
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu():
+
+ trainer = Trainer(accelerator="hpu", devices=1)
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+ assert trainer.num_devices == 1
+
+ trainer = Trainer(accelerator="hpu")
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+ assert trainer.num_devices == 8
+
+ trainer = Trainer(accelerator="auto", devices=8)
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+ assert trainer.num_devices == 8
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu_with_single_device():
+
+ trainer = Trainer(accelerator="hpu", devices=1)
+
+ assert isinstance(trainer.strategy, SingleHPUStrategy)
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu_with_multiple_devices():
+
+ trainer = Trainer(accelerator="hpu", devices=8)
+
+ assert isinstance(trainer.strategy, HPUParallelStrategy)
+ assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_accelerator_auto_with_devices_hpu():
+
+ trainer = Trainer(accelerator="auto", devices=8)
+
+ assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+
+@RunIf(hpu=True)
+def test_strategy_choice_hpu_plugin():
+ trainer = Trainer(strategy=SingleHPUStrategy(device=torch.device("hpu")), accelerator="hpu", devices=1)
+ assert isinstance(trainer.strategy, SingleHPUStrategy)
+
+ trainer = Trainer(accelerator="hpu", devices=1)
+ assert isinstance(trainer.strategy, SingleHPUStrategy)
+
+
+@RunIf(hpu=True)
+def test_strategy_choice_hpu_parallel_plugin():
+ trainer = Trainer(
+ strategy=HPUParallelStrategy(parallel_devices=[torch.device("hpu")] * 8), accelerator="hpu", devices=8
+ )
+ assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+ trainer = Trainer(accelerator="hpu", devices=8)
+ assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+
+@RunIf(hpu=True)
+def test_devices_auto_choice_hpu():
+ trainer = Trainer(accelerator="auto", devices="auto")
+ assert trainer.num_devices == 8
+
+
+@RunIf(hpu=True)
+@pytest.mark.parametrize("hpus", [1])
+def test_inference_only(tmpdir, hpus):
+ model = BoringModel()
+
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="hpu", devices=hpus)
+ trainer.validate(model)
+ trainer.test(model)
+ trainer.predict(model)
+
+
+def test_hpu_auto_device_count():
+ assert HPUAccelerator.auto_device_count() == 8
+
+
+@RunIf(hpu=True)
+def test_hpu_unsupported_device_type():
+ with pytest.raises(MisconfigurationException, match="`devices` for `HPUAccelerator` must be int, string or None."):
+ Trainer(accelerator="hpu", devices=[1])
diff --git a/tests/conftest.py b/tests/conftest.py
index cdb1bfbd39392..b24f1dad8c61b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -198,3 +198,19 @@ def pytest_collection_modifyitems(items):
# has `@RunIf(ipu=True)`
if marker.name == "skipif" and marker.kwargs.get("ipu")
]
+
+
+def pytest_addoption(parser):
+ parser.addoption("--hpus", action="store", type=int, default=1, help="Number of HPU devices to use (1-8)")
+ parser.addoption(
+ "--hmp-bf16", action="store", type=str, default="./ops_bf16_mnist.txt", help="path to the bf16 ops list file for HMP O1 mode"
+ )
+ parser.addoption(
+ "--hmp-fp32", action="store", type=str, default="./ops_fp32_mnist.txt", help="path to the fp32 ops list file for HMP O1 mode"
+ )
+
+
+@pytest.fixture
+def hpus(request):
+ hpus = request.config.getoption("--hpus")
+ return hpus
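
Editorial note: tests pick up the requested card count through the new `hpus` fixture. A hypothetical example of such a test, shown only to illustrate how the `--hpus` option flows through (the test name is illustrative):

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel
    from tests.helpers.runif import RunIf

    @RunIf(hpu=True)
    def test_fit_on_requested_cards(tmpdir, hpus):
        # `hpus` resolves to the value passed via `--hpus` (default: 1)
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="hpu", devices=hpus)
        trainer.fit(BoringModel())
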
diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py
index f52aada92ea30..addc8aec16450 100644
--- a/tests/helpers/runif.py
+++ b/tests/helpers/runif.py
@@ -27,6 +27,7 @@
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HOROVOD_AVAILABLE,
+ _HPU_AVAILABLE,
_IPU_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RICH_AVAILABLE,
@@ -68,6 +69,7 @@ def __new__(
amp_apex: bool = False,
tpu: bool = False,
ipu: bool = False,
+ hpu: bool = False,
horovod: bool = False,
horovod_nccl: bool = False,
skip_windows: bool = False,
@@ -94,6 +96,7 @@ def __new__(
amp_apex: Require that NVIDIA/apex is installed.
tpu: Require that TPU is available.
ipu: Require that IPU is available.
+ hpu: Require that HPU is available.
horovod: Require that Horovod is installed.
horovod_nccl: Require that Horovod is installed with NCCL support.
skip_windows: Skip for Windows platform.
@@ -154,6 +157,10 @@ def __new__(
reasons.append("IPU")
kwargs["ipu"] = True
+ if hpu:
+ conditions.append(not _HPU_AVAILABLE)
+ reasons.append("HPU")
+
if horovod:
conditions.append(not _HOROVOD_AVAILABLE)
reasons.append("Horovod")
diff --git a/tests/plugins/precision/hpu/ops_bf16.txt b/tests/plugins/precision/hpu/ops_bf16.txt
new file mode 100644
index 0000000000000..53ec99c15b4ce
--- /dev/null
+++ b/tests/plugins/precision/hpu/ops_bf16.txt
@@ -0,0 +1,2 @@
+linear
+relu
diff --git a/tests/plugins/precision/hpu/ops_fp32.txt b/tests/plugins/precision/hpu/ops_fp32.txt
new file mode 100644
index 0000000000000..4509b7e58ac29
--- /dev/null
+++ b/tests/plugins/precision/hpu/ops_fp32.txt
@@ -0,0 +1 @@
+cross_entropy
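
Editorial note: the two plain-text files above are the op lists consumed by Habana Mixed Precision (HMP): ops listed in `ops_bf16.txt` run in bf16, while ops listed in `ops_fp32.txt` stay in fp32. They reach the plugin as file paths, roughly as sketched below (illustrative; the test module that follows exercises the same arguments):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import HPUPrecisionPlugin

    plugin = HPUPrecisionPlugin(
        precision="bf16",
        opt_level="O1",
        verbose=False,
        bf16_file_path="tests/plugins/precision/hpu/ops_bf16.txt",
        fp32_file_path="tests/plugins/precision/hpu/ops_fp32.txt",
    )
    trainer = Trainer(accelerator="hpu", devices=1, plugins=[plugin])
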
diff --git a/tests/plugins/precision/hpu/test_hpu.py b/tests/plugins/precision/hpu/test_hpu.py
new file mode 100644
index 0000000000000..5701bf2dc2caa
--- /dev/null
+++ b/tests/plugins/precision/hpu/test_hpu.py
@@ -0,0 +1,96 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import pytest
+import torch
+
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.plugins import HPUPrecisionPlugin
+from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.runif import RunIf
+
+
+@pytest.fixture
+def hmp_params(request):
+ return {
+ "opt_level": "O1",
+ "verbose": False,
+ "bf16_file_path": request.config.getoption("--hmp-bf16"),
+ "fp32_file_path": request.config.getoption("--hmp-fp32"),
+ }
+
+
+@RunIf(hpu=True)
+def test_precision_plugin(hmp_params):
+ plugin = HPUPrecisionPlugin(precision="bf16", **hmp_params)
+ assert plugin.precision == "bf16"
+
+
+@RunIf(hpu=True)
+def test_mixed_precision(tmpdir, hmp_params: dict):
+ class TestCallback(Callback):
+ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+ assert trainer.strategy.model.precision == "bf16"
+ raise SystemExit
+
+ model = BoringModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ fast_dev_run=True,
+ accelerator="hpu",
+ devices=1,
+ plugins=[HPUPrecisionPlugin(precision="bf16", **hmp_params)],
+ callbacks=TestCallback(),
+ )
+ assert isinstance(trainer.strategy, SingleHPUStrategy)
+ assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin)
+ assert trainer.strategy.precision_plugin.precision == "bf16"
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
+
+
+@RunIf(hpu=True)
+def test_pure_half_precision(tmpdir, hmp_params: dict):
+ class TestCallback(Callback):
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ assert trainer.strategy.model.precision == 16
+ for param in trainer.strategy.model.parameters():
+ assert param.dtype == torch.float16
+ raise SystemExit
+
+ model = BoringModel()
+ model = model.half()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ fast_dev_run=True,
+ accelerator="hpu",
+ devices=1,
+ plugins=[HPUPrecisionPlugin(precision=16, **hmp_params)],
+ callbacks=TestCallback(),
+ )
+
+ assert isinstance(trainer.strategy, SingleHPUStrategy)
+ assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin)
+ assert trainer.strategy.precision_plugin.precision == 16
+
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
+
+
+@RunIf(hpu=True)
+def test_unsupported_precision_plugin():
+ with pytest.raises(ValueError, match=r"accelerator='hpu', precision='mixed'\)` is not supported."):
+ HPUPrecisionPlugin(precision="mixed")