diff --git a/.azure-pipelines/hpu-tests.yml b/.azure-pipelines/hpu-tests.yml
index 1be1677045a6c..d846994175f40 100644
--- a/.azure-pipelines/hpu-tests.yml
+++ b/.azure-pipelines/hpu-tests.yml
@@ -31,3 +31,23 @@ jobs:
       apt-get install -y hwinfo
       hwinfo --short
     displayName: 'Instance HW info'
+
+  - bash: |
+      pip install . --requirement requirements/test.txt
+    displayName: 'Install dependencies'
+
+  - bash: |
+      python ".azure-pipelines/run_hpu_tests.py"
+    displayName: 'HPU Tests in parallel'
+
+  - bash: |
+      export PYTHONPATH="${PYTHONPATH}:$(pwd)"
+      python "pl_examples/hpu_examples/simple_mnist/mnist.py"
+    displayName: 'Testing: HPU examples'
+
+  - task: PublishTestResults@2
+    inputs:
+      testResultsFiles: 'hpu*_test-results.xml'
+      testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)'
+    condition: succeededOrFailed()
+    displayName: 'Publish test results'
diff --git a/.azure-pipelines/run_hpu_tests.py b/.azure-pipelines/run_hpu_tests.py
new file mode 100644
index 0000000000000..590c5d9c42251
--- /dev/null
+++ b/.azure-pipelines/run_hpu_tests.py
@@ -0,0 +1,148 @@
+"""This file is called from the hpu-tests.yml pipeline.
+
+The following script runs the HPU tests in parallel.
+The tests run are:
+1. test_inference_only is run on four cards
+2. test_all_stages on two cards
+3. complete HPU tests using one card
+4. complete HPU tests using eight cards.
+"""
+import itertools
+import subprocess
+import sys
+
+HPU_TESTS_DICTIONARY = {
+    "hpu1_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+            --forked \
+            --junitxml=hpu1_test-results.xml",
+    "hpu2_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+            -k test_all_stages \
+            --hpus 2 \
+            --verbose \
+            --capture=no \
+            --forked \
+            --junitxml=hpu2_test-results.xml",
+    "hpu4_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+            -k test_inference_only \
+            --hpus 4 \
+            --capture=no \
+            --verbose \
+            --forked \
+            --junitxml=hpu4_test-results.xml",
+    "hpu8_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/accelerators/test_hpu.py \
+            --forked \
+            --hpus 8 \
+            --junitxml=hpu8_test-results.xml",
+    "hpu1_precision_test": "python -m coverage run --source pytorch_lightning -m pytest -sv tests/plugins/precision/hpu/test_hpu.py \
+            --hmp-bf16 'tests/plugins/precision/hpu/ops_bf16.txt' \
+            --hmp-fp32 'tests/plugins/precision/hpu/ops_fp32.txt' \
+            --forked \
+            --junitxml=hpu1_precision_test-results.xml",
+}
+
+HPU1_TEST = HPU_TESTS_DICTIONARY["hpu1_test"]
+HPU2_TEST = HPU_TESTS_DICTIONARY["hpu2_test"]
+HPU4_TEST = HPU_TESTS_DICTIONARY["hpu4_test"]
+HPU8_TEST = HPU_TESTS_DICTIONARY["hpu8_test"]
+HPU1_PRECISION_TEST = HPU_TESTS_DICTIONARY["hpu1_precision_test"]
+
+PARALLEL_HPU_TESTS_EXECUTION = [[HPU4_TEST, HPU1_TEST], [HPU2_TEST, HPU1_TEST], [HPU8_TEST], [HPU1_PRECISION_TEST]]
+TIMEOUT = 60  # seconds
+TIMEOUT_EXIT_CODE = -9
+
+
+def run_hpu_tests_parallel(timeout=TIMEOUT):
+    """This function is called to run the HPU tests in parallel.
+
+    We run the tests in subprocesses to utilize all eight cards available in the DL1 instance.
+    Considering 60 seconds as the maximum time the HPU tests should take, we kill a process if its run time
+    exceeds that threshold.
+
+    Args:
+        timeout: The threshold time to run the HPU tests in parallel.
+            An exception is logged if the timeout threshold expires.
+            TIMEOUT_EXIT_CODE will be returned as -9 in case of timeout,
+            0 in case of success and 4 in case of failure.
+
+    Return:
+        The list of exit statuses of the HPU tests that were run in the subprocesses.
+        Here, an exit_status of 0 means the test run was successful and an exit_status of 1 means the test run failed.
+    """
+    exit_status = []
+    with open("stdout_log.txt", "w") as stdout_log, open("error_log.txt", "w") as error_log:
+        for hpu_tests in PARALLEL_HPU_TESTS_EXECUTION:
+            process_list = [
+                subprocess.Popen(
+                    each_hpu_test, shell=True, stdout=stdout_log, stderr=error_log, universal_newlines=True
+                )
+                for each_hpu_test in hpu_tests
+            ]
+            for process in process_list:
+                try:
+                    exit_status.append(process.wait(timeout=TIMEOUT))
+                except subprocess.TimeoutExpired as e:
+                    print(e)
+                    print("Killing the process....")
+                    process.kill()
+                    exit_status.append(TIMEOUT_EXIT_CODE)
+    return exit_status
+
+
+def zip_cmd_exitcode(exit_status):
+    """This function zips each test command that was executed with its exit status.
+
+    Args:
+        exit_status: The returned exit_status after executing run_hpu_tests_parallel().
+
+    Return:
+        A list of the HPU tests that were called and their exit statuses.
+    """
+    status_list = []
+    status_list = list(zip(list(itertools.chain(*PARALLEL_HPU_TESTS_EXECUTION)), exit_status))
+    return status_list
+
+
+def print_logs(filename):
+    """This function is called to read the file and print the logs.
+
+    Args:
+        filename: The log file that needs to be printed on the console.
+    """
+    with open(filename) as f:
+        print(f.read())
+
+
+def print_subprocess_logs_and_return_status(exit_status):
+    """This function prints the logs of the subprocess stdout and stderr and returns the status of the test
+    execution.
+
+    Args:
+        exit_status: The returned exit_status after executing run_hpu_tests_parallel().
+
+    Return:
+        Based on the exit status of the HPU tests, we return success or failure to the main method.
+    """
+    if all(v == 0 for v in exit_status):
+        print("All HPU tests passed")
+        file_name = "stdout_log.txt"
+        print_logs(file_name)
+        return 0
+    else:
+        print("HPU tests are failing")
+        print("Printing stdout_log.txt...")
+        file_name = "stdout_log.txt"
+        print_logs(file_name)
+        print("Printing error_log.txt...")
+        file_name = "error_log.txt"
+        print_logs(file_name)
+        return 1
+
+
+def main():
+    exit_status = run_hpu_tests_parallel(timeout=TIMEOUT)
+    status_list = zip_cmd_exitcode(exit_status)
+    print("HPU Tests executed and their exit status:", status_list)
+    return print_subprocess_logs_and_return_status(exit_status)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ab43e60a26a4..76e869240a87b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))
 
+- Added support for Habana Accelerator (HPU) ([#11808](https://github.com/PyTorchLightning/pytorch-lightning/pull/11808))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
diff --git a/docs/source/accelerators/hpu.rst b/docs/source/accelerators/hpu.rst
new file mode 100644
index 0000000000000..fd7bd310ffc43
--- /dev/null
+++ b/docs/source/accelerators/hpu.rst
@@ -0,0 +1,124 @@
+.. _hpu:
+
+Habana Gaudi AI Processor (HPU)
+===============================
+
+Lightning supports `Habana Gaudi AI Processor (HPU) `__ for accelerating Deep Learning training workloads.
+
+HPU Terminology
+---------------
+
+Habana® Gaudi® AI training processors are built on a heterogeneous architecture with a cluster of fully programmable Tensor Processing Cores (TPC), their associated development tools and libraries, and a configurable Matrix Math engine.
+
+The TPC core is a VLIW SIMD processor with an instruction set and hardware tailored to serve training workloads efficiently.
+The Gaudi memory architecture includes on-die SRAM and local memories in each TPC.
+Gaudi is the first DL training processor with integrated RDMA over Converged Ethernet (RoCE v2) engines on-chip.
+
+On the software side, the PyTorch Habana bridge interfaces between the framework and the SynapseAI software stack to enable the execution of deep learning models on the Habana Gaudi device.
+
+Gaudi offers a substantial price/performance advantage -- so you get to do more deep learning training while spending less.
+
+For more information, check out `Gaudi Architecture `__ and `Gaudi Developer Docs `__.
+
+How to access HPUs
+------------------
+
+To use HPUs, you must have access to a system with HPU devices.
+You can either use `Gaudi-based AWS EC2 DL1 instances `__ or a `Supermicro X12 Gaudi server `__ to get access to HPUs.
+
+Check out the `Getting Started Guide with AWS and Habana `__.
+
+Training with HPUs
+------------------
+
+To enable PyTorch Lightning to use the HPU accelerator, simply pass the ``accelerator="hpu"`` parameter to the Trainer class.
+
+.. code-block:: python
+
+    trainer = Trainer(accelerator="hpu")
+
+Passing ``devices=1`` and ``accelerator="hpu"`` to the Trainer class enables the Habana accelerator for single-Gaudi training.
+
+.. code-block:: python
+
+    trainer = Trainer(devices=1, accelerator="hpu")
+
+Passing ``devices=8`` and ``accelerator="hpu"`` to the Trainer class enables the Habana accelerator for distributed training with 8 Gaudis.
+It uses :class:`~pytorch_lightning.strategies.hpu_parallel.HPUParallelStrategy` internally, which is based on the DDP strategy and adds Habana's collective communication library (HCCL) to support scale-up within a node and scale-out across multiple nodes.
+
+.. code-block:: python
+
+    trainer = Trainer(devices=8, accelerator="hpu")
+
+.. note::
+    If the ``devices`` flag is not defined, it will assume ``devices`` to be ``"auto"`` and select 8 Gaudi devices for :class:`~pytorch_lightning.accelerators.hpu.HPUAccelerator`.
+
+
+Mixed Precision Plugin
+----------------------
+
+Lightning also allows mixed precision training with HPUs.
+By default, HPU training uses 32-bit precision. To enable mixed precision, set the ``precision`` flag.
+
+.. code-block:: python
+
+    trainer = Trainer(devices=1, accelerator="hpu", precision=16)
+
+
+Enabling Mixed Precision Options
+--------------------------------
+
+Internally, :class:`~pytorch_lightning.plugins.precision.hpu.HPUPrecisionPlugin` uses the Habana Mixed Precision (HMP) package to enable mixed precision training.
+
+You can execute the ops in FP32 or BF16 precision. The HMP package modifies the Python operators to add the appropriate cast operations for the arguments before execution.
+The default settings let users enable mixed precision training with minimal code.
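For orientation, here is a minimal sketch of that default path (an editorial illustration, using only classes introduced in this diff): constructing the precision plugin without operator lists has the same effect as passing ``precision=16`` to the Trainer, and HMP falls back to its default ``opt_level="O2"`` with the built-in BF16/FP32 op lists.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import HPUPrecisionPlugin

    # Same effect as Trainer(accelerator="hpu", devices=1, precision=16):
    # HPUPrecisionPlugin defaults to opt_level="O2" and HMP's built-in op lists.
    trainer = Trainer(
        accelerator="hpu",
        devices=1,
        plugins=[HPUPrecisionPlugin(precision=16)],
    )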
+
+In addition to the default settings in HMP, users also have the option of overriding these defaults and providing their
+own BF16 and FP32 operator lists by passing them as parameters to :class:`~pytorch_lightning.plugins.precision.hpu.HPUPrecisionPlugin`.
+
+The snippet below shows an example model using MNIST with a single Habana Gaudi device, making use of HMP by overriding the default parameters.
+This enables advanced users to provide their own BF16 and FP32 operator lists instead of using the HMP defaults.
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from pytorch_lightning.plugins import HPUPrecisionPlugin
+
+    # Initialize a trainer with the HPU accelerator for a single device,
+    # with mixed precision using overridden HMP settings
+    trainer = pl.Trainer(
+        accelerator="hpu",
+        devices=1,
+        # Optional Habana mixed precision params to be set
+        # Check out `pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt` for the format
+        plugins=[
+            HPUPrecisionPlugin(
+                precision=16,
+                opt_level="O1",
+                verbose=False,
+                bf16_file_path="ops_bf16_mnist.txt",
+                fp32_file_path="ops_fp32_mnist.txt",
+            )
+        ],
+    )
+
+    # Init our model
+    model = LitClassifier()
+    # Init the data
+    dm = MNISTDataModule(batch_size=batch_size)
+
+    # Train the model ⚡
+    trainer.fit(model, datamodule=dm)
+
+For more details, please refer to `PyTorch Mixed Precision Training on Gaudi `__.
+
+----------------
+
+.. _known-limitations_hpu:
+
+Known limitations
+-----------------
+
+* Multiple optimizers are not supported.
+* `Habana dataloader `__ is not supported.
+* :class:`~pytorch_lightning.callbacks.device_stats_monitor.DeviceStatsMonitor` is not supported.
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index b0bcf4e74ff8b..23a43990fed39 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -16,6 +16,7 @@ Accelerator API
     Accelerator
     CPUAccelerator
     GPUAccelerator
+    HPUAccelerator
     IPUAccelerator
     TPUAccelerator
 
@@ -59,9 +60,11 @@ Strategy API
     DataParallelStrategy
     DeepSpeedStrategy
     HorovodStrategy
+    HPUParallelStrategy
     IPUStrategy
     ParallelStrategy
     SingleDeviceStrategy
+    SingleHPUStrategy
     SingleTPUStrategy
     Strategy
     TPUSpawnStrategy
@@ -198,6 +201,7 @@ Precision Plugins
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FullyShardedNativeMixedPrecisionPlugin
+    HPUPrecisionPlugin
     IPUPrecisionPlugin
     MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
@@ -234,6 +238,7 @@ Checkpoint IO Plugins
     :template: classtemplate.rst
 
     CheckpointIO
+    HPUCheckpointIO
     TorchCheckpointIO
     XLACheckpointIO
diff --git a/docs/source/extensions/accelerator.rst b/docs/source/extensions/accelerator.rst
index 762f2b6e57a90..0d78371a0e7ed 100644
--- a/docs/source/extensions/accelerator.rst
+++ b/docs/source/extensions/accelerator.rst
@@ -15,6 +15,7 @@ Currently there are accelerators for:
 - GPU
 - TPU
 - IPU
+- HPU
 
 Each Accelerator gets two plugins upon initialization:
 One to handle differences from the training routine and one to handle different precisions.
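To make that pairing concrete for the new HPU classes, here is a minimal sketch (an editorial illustration, not part of the patch; it assumes a machine where ``_HPU_AVAILABLE`` is true) of roughly the explicit form of ``Trainer(accelerator="hpu", devices=1, precision=16)``: the accelerator owns the device handling, the strategy owns the training routine, and the precision plugin owns the HMP mixed-precision setup.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import HPUAccelerator
    from pytorch_lightning.plugins import HPUPrecisionPlugin
    from pytorch_lightning.strategies import SingleHPUStrategy

    # Explicitly compose the accelerator, strategy, and precision plugin.
    trainer = Trainer(
        accelerator=HPUAccelerator(),
        devices=1,
        strategy=SingleHPUStrategy(device="hpu"),
        plugins=[HPUPrecisionPlugin(precision=16)],
    )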
@@ -58,5 +59,6 @@ Accelerator API Accelerator CPUAccelerator GPUAccelerator - TPUAccelerator + HPUAccelerator IPUAccelerator + TPUAccelerator diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index 252fb47570fc4..3bfa7ad24b29c 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -61,17 +61,18 @@ Precision Plugins :nosignatures: :template: classtemplate.rst - PrecisionPlugin - MixedPrecisionPlugin - NativeMixedPrecisionPlugin - ShardedNativeMixedPrecisionPlugin ApexMixedPrecisionPlugin DeepSpeedPrecisionPlugin - TPUPrecisionPlugin - TPUBf16PrecisionPlugin DoublePrecisionPlugin FullyShardedNativeMixedPrecisionPlugin + HPUPrecisionPlugin IPUPrecisionPlugin + MixedPrecisionPlugin + NativeMixedPrecisionPlugin + PrecisionPlugin + ShardedNativeMixedPrecisionPlugin + TPUBf16PrecisionPlugin + TPUPrecisionPlugin Cluster Environments diff --git a/docs/source/extensions/strategy.rst b/docs/source/extensions/strategy.rst index e85b719e8566c..7c5596c7362ea 100644 --- a/docs/source/extensions/strategy.rst +++ b/docs/source/extensions/strategy.rst @@ -108,9 +108,11 @@ Built-In Training Strategies DataParallelStrategy DeepSpeedStrategy HorovodStrategy + HPUParallelStrategy IPUStrategy ParallelStrategy SingleDeviceStrategy + SingleHPUStrategy SingleTPUStrategy Strategy TPUSpawnStrategy diff --git a/docs/source/index.rst b/docs/source/index.rst index 4e7543b51f4a7..d02eff4fb225a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -88,6 +88,7 @@ Welcome to PyTorch Lightning accelerators/gpu accelerators/tpu accelerators/ipu + accelerators/hpu .. toctree:: :maxdepth: 1 diff --git a/pl_examples/hpu_examples/simple_mnist/mnist.py b/pl_examples/hpu_examples/simple_mnist/mnist.py new file mode 100644 index 0000000000000..a5d4b47d6b829 --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/mnist.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
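+# A minimal MNIST classifier for a single Gaudi (HPU) device: the Trainer defaults
+# below select the HPU accelerator, one device and HMP mixed precision, and the CI
+# pipeline runs this file via `python "pl_examples/hpu_examples/simple_mnist/mnist.py"`.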
+import torch +from jsonargparse import lazy_instance +from torch.nn import functional as F + +import pytorch_lightning as pl +from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule +from pytorch_lightning.plugins import HPUPrecisionPlugin +from pytorch_lightning.utilities.cli import LightningCLI + + +class LitClassifier(pl.LightningModule): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(28 * 28, 10) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + probs = self(x) + acc = self.accuracy(probs, y) + self.log("val_acc", acc) + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + self.log("test_acc", acc) + + @staticmethod + def accuracy(logits, y): + acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y) + return acc + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + +if __name__ == "__main__": + cli = LightningCLI( + LitClassifier, + MNISTDataModule, + trainer_defaults={ + "accelerator": "hpu", + "devices": 1, + "max_epochs": 1, + "plugins": lazy_instance(HPUPrecisionPlugin, precision=16), + }, + run=False, + save_config_overwrite=True, + ) + + # Run the model ⚡ + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.validate(cli.model, datamodule=cli.datamodule) + cli.trainer.test(cli.model, datamodule=cli.datamodule) diff --git a/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt new file mode 100644 index 0000000000000..53ec99c15b4ce --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt @@ -0,0 +1,2 @@ +linear +relu diff --git a/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt new file mode 100644 index 0000000000000..4509b7e58ac29 --- /dev/null +++ b/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt @@ -0,0 +1 @@ +cross_entropy diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py index fe9ae0a120cfb..1ab90e025b087 100644 --- a/pytorch_lightning/accelerators/__init__.py +++ b/pytorch_lightning/accelerators/__init__.py @@ -13,6 +13,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401 from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401 +from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401 from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401 diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 526cec3e47319..5fe6b53dd54b5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,6 +28,7 @@ class Accelerator(ABC): - GPU - TPU - IPU + - HPU """ def setup_environment(self, root_device: torch.device) -> None: diff --git a/pytorch_lightning/accelerators/hpu.py b/pytorch_lightning/accelerators/hpu.py new file mode 100644 index 
0000000000000..76fdb02b307b8 --- /dev/null +++ b/pytorch_lightning/accelerators/hpu.py @@ -0,0 +1,69 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Union + +import torch + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import _HPU_AVAILABLE, device_parser +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_debug + + +class HPUAccelerator(Accelerator): + """Accelerator for HPU devices.""" + + def setup_environment(self, root_device: torch.device) -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not HPU. + """ + super().setup_environment(root_device) + if root_device.type != "hpu": + raise MisconfigurationException(f"Device should be HPU, got {root_device} instead.") + + def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: + """HPU device stats aren't supported yet.""" + rank_zero_debug("HPU device stats aren't supported yet.") + return {} + + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[int]: + """Accelerator device parsing logic.""" + return device_parser.parse_hpus(devices) + + @staticmethod + def get_parallel_devices(devices: int) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("hpu")] * devices + + @staticmethod + def auto_device_count() -> int: + """Get the devices when set to auto.""" + # TODO(@kaushikb11): Update this when api is exposed by the Habana team + return 8 + + @staticmethod + def is_available() -> bool: + return _HPU_AVAILABLE + + @classmethod + def register_accelerators(cls, accelerator_registry: Dict) -> None: + accelerator_registry.register( + "hpu", + cls, + description=f"{cls.__class__.__name__}", + ) diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py new file mode 100644 index 0000000000000..9c70a2867b429 --- /dev/null +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -0,0 +1,168 @@ +# type: ignore + +import io +import logging +import os +import pickle + +import torch +from torch._C._distributed_c10d import ProcessGroup + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +logger = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember + +# The code underneath is taken from PyTorch `torch/distributed/distributed_c10d.py` +# the distributed backend and tensor type updates for habana backend is done here before broadcast + + +# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L267 # noqa: E501 +def _rank_not_in_group(group: ProcessGroup): + """Helper that checks if the current process's rank is not in a given group.""" + if group is None: + 
return False + return group == GroupMember.NON_GROUP_MEMBER + + +# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L1551 # noqa: E501 +def _object_to_tensor(obj): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +# Taken from https://github.com/pytorch/pytorch/blob/3466c1b6901f06a563b8cbfa3c942fa50bda835b/torch/distributed/distributed_c10d.py#L1563 # noqa: E501 +def _tensor_to_object(tensor, tensor_size): + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _broadcast_object_list(object_list, src=0, group=None, device=None): + """Broadcasts picklable objects in ``object_list`` to the whole group. Similar to :func:`broadcast`, but Python + objects can be passed in. Note that all objects in ``object_list`` must be picklable in order to be + broadcasted. + + Args: + object_list: List of input objects to broadcast. + Each object must be picklable. Only objects on the ``src`` rank will + be broadcast, but each rank must provide lists of equal sizes. + src: Source rank from which to broadcast ``object_list``. + group: The process group to work on. If None, + the default process group will be used. Default is ``None``. + device: If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before broadcasting. Default is ``None``. + + Returns: + ``None``. If rank is part of the group, ``object_list`` will contain the + broadcasted objects from ``src`` rank. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. warning:: + :func:`broadcast_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + """ + if _rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. 
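+    # The Habana-specific branch below extends this device selection for the HCCL
+    # backend: when HCCL_DISTRIBUTED_BACKEND=1, tensors are moved to the "hpu" device
+    # and the size tensor is cast to int, since HPU collectives do not support Long tensors.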
+ group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1" + current_device = None + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in + # docstring. We cannot simply use my_rank since rank == device is + # not necessarily true. + current_device = torch.device("cuda", torch.cuda.current_device()) + if is_nccl_backend: + object_sizes_tensor = object_sizes_tensor.to(current_device) + + elif is_hpu_backend: + current_device = torch.device("hpu") + # Workaround: HPU doesn't not support long tensors for collectives + if (object_sizes_tensor.type() == "torch.LongTensor") or (object_sizes_tensor.type() == "torch.hpu.LongTensor"): + object_sizes_tensor = object_sizes_tensor.int() + else: + print("unhandled hpu object_sizes_tensor type :: ", object_sizes_tensor.type()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( + torch.sum(object_sizes_tensor).int().item(), + dtype=torch.uint8, + ) + + if is_nccl_backend or is_hpu_backend: + object_tensor = object_tensor.to(current_device) + + broadcast(object_tensor, src=src, group=group) + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + if obj_view.device != torch.device("cpu"): + obj_view = obj_view.cpu() + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +if not torch.distributed.is_available(): + # avoid failures on early PyTorch versions for Windows where + # not all functions used in `broadcast_object_list` are available. 
+ def _broadcast_noop(obj, *_, **__): + return obj + + broadcast_object_list = _broadcast_noop +else: + broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 2e4e0da6b2755..0f1c4ca85ed5a 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -2,6 +2,7 @@ from pytorch_lightning.plugins.environments import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm @@ -9,6 +10,7 @@ from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin @@ -39,6 +41,7 @@ "CheckpointIO", "TorchCheckpointIO", "XLACheckpointIO", + "HPUCheckpointIO", "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", @@ -51,6 +54,7 @@ "HorovodPlugin", "IPUPlugin", "IPUPrecisionPlugin", + "HPUPrecisionPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", diff --git a/pytorch_lightning/plugins/io/__init__.py b/pytorch_lightning/plugins/io/__init__.py index 1b14eee6ec4f2..abd196eb2b1e3 100644 --- a/pytorch_lightning/plugins/io/__init__.py +++ b/pytorch_lightning/plugins/io/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO # noqa: F401 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO # noqa: F401 diff --git a/pytorch_lightning/plugins/io/hpu_plugin.py b/pytorch_lightning/plugins/io/hpu_plugin.py new file mode 100644 index 0000000000000..c72d1d9fcd112 --- /dev/null +++ b/pytorch_lightning/plugins/io/hpu_plugin.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
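+# HPUCheckpointIO extends TorchCheckpointIO: before writing, the checkpoint dict is
+# moved to CPU with `move_data_to_device` and then persisted with `atomic_save`.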
+ +import os +from typing import Any, Dict, Optional + +import torch + +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities.types import _PATH + + +class HPUCheckpointIO(TorchCheckpointIO): + """CheckpointIO to save checkpoints for HPU training strategies.""" + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: not used in ``XLACheckpointIO.save_checkpoint`` + + Raises: + TypeError: + If ``storage_options`` arg is passed in + """ + if storage_options is not None: + raise TypeError( + "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" + f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" + " to define how you'd like to use `storage_options`." + ) + fs = get_filesystem(path) + fs.makedirs(os.path.dirname(path), exist_ok=True) + + checkpoint = move_data_to_device(checkpoint, torch.device("cpu")) + # write the checkpoint dictionary to the provided path + atomic_save(checkpoint, path) diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index b407e47ca9337..4bc29c1be1864 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,9 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 FullyShardedNativeMixedPrecisionPlugin, ) +from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/hpu.py b/pytorch_lightning/plugins/precision/hpu.py new file mode 100644 index 0000000000000..3c02d82a2de10 --- /dev/null +++ b/pytorch_lightning/plugins/precision/hpu.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HPU_AVAILABLE + +if _HPU_AVAILABLE: + from habana_frameworks.torch.hpex import hmp + + +class HPUPrecisionPlugin(PrecisionPlugin): + """Plugin that enables bfloat/half support on HPUs. + + Args: + precision: The precision to use. + opt_level: Choose optimization level for hmp. + bf16_file_path: Path to bf16 ops list in hmp O1 mode. + fp32_file_path: Path to fp32 ops list in hmp O1 mode. + verbose: Enable verbose mode for hmp. + """ + + def __init__( + self, + precision: Union[str, int], + opt_level: str = "O2", + bf16_file_path: Optional[str] = None, + fp32_file_path: Optional[str] = None, + verbose: bool = False, + ) -> None: + if not _HPU_AVAILABLE: + raise MisconfigurationException("HPU precision plugin requires HPU devices.") + supported_precision_values = (16, 32, "bf16") + if precision not in supported_precision_values: + raise ValueError( + f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." + f" `precision` must be one of: {supported_precision_values}." + ) + super().__init__() + self.precision = precision + hmp.convert( + opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose + ) diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py index a4cd57a50ac1d..38a2b466e57e9 100644 --- a/pytorch_lightning/strategies/__init__.py +++ b/pytorch_lightning/strategies/__init__.py @@ -19,11 +19,13 @@ from pytorch_lightning.strategies.dp import DataParallelStrategy # noqa: F401 from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy # noqa: F401 from pytorch_lightning.strategies.horovod import HorovodStrategy # noqa: F401 +from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy # noqa: F401 from pytorch_lightning.strategies.ipu import IPUStrategy # noqa: F401 from pytorch_lightning.strategies.parallel import ParallelStrategy # noqa: F401 from pytorch_lightning.strategies.sharded import DDPShardedStrategy # noqa: F401 from pytorch_lightning.strategies.sharded_spawn import DDPSpawnShardedStrategy # noqa: F401 from pytorch_lightning.strategies.single_device import SingleDeviceStrategy # noqa: F401 +from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy # noqa: F401 from pytorch_lightning.strategies.single_tpu import SingleTPUStrategy # noqa: F401 from pytorch_lightning.strategies.strategy import Strategy # noqa: F401 from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry # noqa: F401 diff --git a/pytorch_lightning/strategies/hpu_parallel.py b/pytorch_lightning/strategies/hpu_parallel.py new file mode 100644 index 0000000000000..562a841b89510 --- /dev/null +++ b/pytorch_lightning/strategies/hpu_parallel.py @@ -0,0 +1,132 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Dict, List, Optional + +import torch +import torch.distributed + +import pytorch_lightning as pl +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.utilities.distributed import group as _group +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_LESSER_EQUAL_1_10_2 + +if _HPU_AVAILABLE: + import habana_frameworks.torch.core.hccl # noqa: F401 + from habana_frameworks.torch.utils.library_loader import load_habana_module + +log = logging.getLogger(__name__) + + +class HPUParallelStrategy(DDPStrategy): + """Strategy for distributed training on multiple HPU devices.""" + + strategy_name = "hpu_parallel" + + def __init__( + self, + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, + parallel_devices: Optional[List[torch.device]] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, + process_group_backend: Optional[str] = "hccl", + ) -> None: + + if not _HPU_AVAILABLE: + raise MisconfigurationException("`HPUParallelStrategy` requires HPU devices to run") + + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + checkpoint_io=checkpoint_io or HPUCheckpointIO(), + precision_plugin=precision_plugin, + process_group_backend=process_group_backend, + ) + + def setup_environment(self) -> None: + # This function is used to load Habana libraries required for PyTorch + # to register HPU as one of the available devices. + load_habana_module() + + os.environ["ID"] = str(self.local_rank) + if self._process_group_backend == "hccl": + # this env is used in overrides to check the backend initiated + os.environ["HCCL_DISTRIBUTED_BACKEND"] = str(1) + super().setup_environment() + + def determine_ddp_device_ids(self) -> None: + return None + + def pre_configure_ddp(self): # type: ignore + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True) + + self._static_graph = False + static_graph = self._ddp_kwargs.get("static_graph") + if static_graph: + # when _set_static_graph() is called find_unused_parameters does not have any significance. 
+            # Resetting the value of find_unused_parameters to False which is the default value for DDP
+            self._ddp_kwargs["find_unused_parameters"] = False
+            self._static_graph = True
+        if static_graph is not None:
+            # DDP does not accept static_graph as a parameter, hence removing it from the list
+            del self._ddp_kwargs["static_graph"]
+
+    def configure_ddp(self) -> None:
+        # DDP does not accept static graph as param with torch < 1.11
+        if _TORCH_LESSER_EQUAL_1_10_2:
+            log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
+            self.pre_configure_ddp()
+            self.model = self._setup_model(LightningDistributedModule(self.model))  # type: ignore
+            if self.root_device.type == "hpu" and self._static_graph:
+                self._model._set_static_graph()  # type: ignore
+            self._register_ddp_hooks()
+        else:
+            super().configure_ddp()
+
+    def broadcast(self, obj: object, src: int = 0) -> object:  # type: ignore
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]
+
+    def teardown(self) -> None:
+        log.detail(f"{self.__class__.__name__}: tearing down strategy.")
+        super().teardown()
+
+        log.detail(f"{self.__class__.__name__}: moving model to CPU")
+        self.lightning_module.cpu()  # type: ignore
+        # Was set to local rank
+        os.environ.pop("ID", None)
+        os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None)
+
+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        strategy_registry.register(
+            cls.strategy_name,
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/strategies/single_hpu.py b/pytorch_lightning/strategies/single_hpu.py
new file mode 100644
index 0000000000000..edafe63441906
--- /dev/null
+++ b/pytorch_lightning/strategies/single_hpu.py
@@ -0,0 +1,80 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import Dict, Optional + +import pytorch_lightning as pl +from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.single_device import SingleDeviceStrategy +from pytorch_lightning.utilities import _HPU_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _DEVICE + +if _HPU_AVAILABLE: + import habana_frameworks.torch.core.hccl # noqa: F401 + from habana_frameworks.torch.utils.library_loader import load_habana_module + + +class SingleHPUStrategy(SingleDeviceStrategy): + """Strategy for training on single HPU device.""" + + strategy_name = "hpu_single" + + def __init__( + self, + device: _DEVICE = "hpu", + accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, + checkpoint_io: Optional[HPUCheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, + ): + + if not _HPU_AVAILABLE: + raise MisconfigurationException("`SingleHPUStrategy` requires HPU devices to run") + + # This function is used to load Habana libraries required for PyTorch + # to register HPU as one of the available devices. + load_habana_module() + + super().__init__( + accelerator=accelerator, + device=device, + checkpoint_io=checkpoint_io or HPUCheckpointIO(), + precision_plugin=precision_plugin, + ) + + @property + def is_distributed(self) -> bool: + return False + + def setup(self, trainer: "pl.Trainer") -> None: + self.model_to_device() + super().setup(trainer) + + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + super().setup_optimizers(trainer) + + if len(self.optimizers) > 1: + raise MisconfigurationException("HPUs currently support only one optimizer.") + + def model_to_device(self) -> None: + self.model.to(self.root_device) # type: ignore + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + strategy_registry.register( + cls.strategy_name, + cls, + description=f"{cls.__class__.__name__}", + ) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 43f87ec731acf..28c6da86c6d59 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -22,6 +22,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.hpu import HPUAccelerator from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.accelerators.registry import AcceleratorRegistry from pytorch_lightning.accelerators.tpu import TPUAccelerator @@ -31,6 +32,7 @@ DeepSpeedPrecisionPlugin, DoublePrecisionPlugin, FullyShardedNativeMixedPrecisionPlugin, + HPUPrecisionPlugin, IPUPrecisionPlugin, NativeMixedPrecisionPlugin, PLUGIN_INPUT, @@ -58,8 +60,10 @@ DDPStrategy, DeepSpeedStrategy, HorovodStrategy, + HPUParallelStrategy, IPUStrategy, SingleDeviceStrategy, + SingleHPUStrategy, SingleTPUStrategy, Strategy, StrategyRegistry, @@ -77,6 +81,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import ( _HOROVOD_AVAILABLE, + _HPU_AVAILABLE, _IPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE, @@ -187,7 +192,6 @@ def __init__( 
self._check_device_config_and_set_final_flags( devices=devices, num_nodes=num_nodes, num_processes=num_processes, gpus=gpus, ipus=ipus, tpu_cores=tpu_cores ) - # 2. Instantiate Accelerator # handle `auto` and `None` self._set_accelerator_if_ipu_strategy_is_passed() @@ -427,7 +431,7 @@ def _check_device_config_and_set_final_flags( if self._devices_flag == "auto" and self._accelerator_flag is None: raise MisconfigurationException( f"You passed `devices={devices}` but haven't specified" - " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping" + " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping" ) def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag( @@ -481,6 +485,8 @@ def _choose_accelerator(self) -> str: return "tpu" if _IPU_AVAILABLE: return "ipu" + if _HPU_AVAILABLE: + return "hpu" if torch.cuda.is_available() and torch.cuda.device_count() > 0: return "gpu" return "cpu" @@ -545,6 +551,11 @@ def _is_slurm_managing_tasks(self) -> bool: def _choose_strategy(self) -> Union[Strategy, str]: if self._accelerator_flag == "ipu": return IPUStrategy.strategy_name + if self._accelerator_flag == "hpu": + if self._parallel_devices and len(self._parallel_devices) > 1: + return HPUParallelStrategy.strategy_name + else: + return SingleHPUStrategy(device=torch.device("hpu")) if self._accelerator_flag == "tpu": if self._parallel_devices and len(self._parallel_devices) > 1: return TPUSpawnStrategy.strategy_name @@ -642,6 +653,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.accelerator, IPUAccelerator): return IPUPrecisionPlugin(self._precision_flag) # type: ignore + if isinstance(self.accelerator, HPUAccelerator): + return HPUPrecisionPlugin(self._precision_flag) # type: ignore if isinstance(self.accelerator, TPUAccelerator): if self._precision_flag == 32: return TPUPrecisionPlugin() @@ -707,6 +720,11 @@ def _validate_precision_choice(self) -> None: f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," f" found: {self._precision_plugin_flag}." ) + if isinstance(self.accelerator, HPUAccelerator): + if self._precision_flag not in (16, "bf16", 32): + raise MisconfigurationException( + f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported." + ) if ( self._precision_flag == 16 and isinstance(self.accelerator, CPUAccelerator) @@ -768,7 +786,15 @@ def _lazy_init_strategy(self) -> None: ): raise ValueError( "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`," - f" found {self.strategy}." + f" found {self.strategy.__class__.__name__}." + ) + + if isinstance(self.accelerator, HPUAccelerator) and not isinstance( + self.strategy, (SingleHPUStrategy, HPUParallelStrategy) + ): + raise ValueError( + "The `HPUAccelerator` can only be used with a `SingleHPUStrategy` or `HPUParallelStrategy`," + f" found {self.strategy.__class__.__name__}." 
) """The following properties are here for backward-compatibility and will be deprecated and removed in favor @@ -801,6 +827,7 @@ def is_distributed(self) -> bool: DeepSpeedStrategy, TPUSpawnStrategy, HorovodStrategy, + HPUParallelStrategy, ) is_distributed = isinstance(self.strategy, distributed_strategy) if isinstance(self.accelerator, TPUAccelerator): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fab1a5e9b9d83..bcda0764ff11f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -31,7 +31,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule @@ -77,6 +77,7 @@ from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( + _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE, AMPType, @@ -195,7 +196,7 @@ def __init__( Args: - accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto") + accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "auto") as well as custom accelerator instances. .. deprecated:: v1.5 @@ -353,7 +354,7 @@ def __init__( Default: ``None``. precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). - Can be used on CPU, GPU, TPUs or IPUs. + Can be used on CPU, GPU, TPUs, HPUs or IPUs. Default: ``32``. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). @@ -1818,6 +1819,9 @@ def _log_device_info(self) -> None: num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0 rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs") + num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0 + rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs") + if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator): rank_zero_warn( "GPU available but not used. Set `accelerator` and `devices` using" @@ -1837,6 +1841,12 @@ def _log_device_info(self) -> None: f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`." ) + if _HPU_AVAILABLE and not isinstance(self.accelerator, HPUAccelerator): + rank_zero_warn( + "HPU available but not used. Set `accelerator` and `devices` using" + f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`." 
+ ) + """ Data loading methods """ diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 916532d964952..930467e50108d 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -36,6 +36,7 @@ _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, _GROUP_AVAILABLE, _HOROVOD_AVAILABLE, + _HPU_AVAILABLE, _HYDRA_AVAILABLE, _HYDRA_EXPERIMENTAL_AVAILABLE, _IPU_AVAILABLE, diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index d7b8a319ea4d2..5f886a4e36598 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -243,3 +243,24 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: if tpu_cores in ("1", "8"): return int(tpu_cores) return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0] + + +def parse_hpus(devices: Optional[Union[int, str, List[int]]]) -> Optional[int]: + """ + Parses the hpus given in the format as accepted by the + :class:`~pytorch_lightning.trainer.Trainer` for the `devices` flag. + + Args: + devices: An integer that indicates the number of Gaudi devices to be used + + Returns: + Either an integer or ``None`` if no devices were requested + + Raises: + MisconfigurationException: + If devices aren't of type `int` or `str` + """ + if devices is not None and not isinstance(devices, (int, str)): + raise MisconfigurationException("`devices` for `HPUAccelerator` must be int, string or None.") + + return int(devices) if isinstance(devices, str) else devices diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index d5dca963f12c4..b94351b22a335 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -21,7 +21,12 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE +from pytorch_lightning.utilities.imports import ( + _HPU_AVAILABLE, + _TORCH_GREATER_EQUAL_1_8, + _TORCH_GREATER_EQUAL_1_9, + _TPU_AVAILABLE, +) from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @@ -124,6 +129,14 @@ def sync_ddp( else: op = reduce_op + # WA for HPU. 
HPU doesn't support Long types, forcefully set it to float + if _HPU_AVAILABLE: + is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1" + if is_hpu_backend: + if (result.type() == "torch.LongTensor") or (result.type() == "torch.hpu.LongTensor"): + new_rank_zero_info("Long tensor unsupported on HPU, casting to float") + result = result.float() + # sync all processes before reduction torch.distributed.barrier(group=group) torch.distributed.all_reduce(result, op=op, group=group, async_op=False) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 105b167a29910..f9ae0c82d0444 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -123,6 +123,7 @@ class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta): DDP_SHARDED = "ddp_sharded" DDP_SHARDED_SPAWN = "ddp_sharded_spawn" DDP_FULLY_SHARDED = "ddp_fully_sharded" + HPU_PARALLEL = "hpu_parallel" @staticmethod def interactive_compatible_types() -> list[DistributedType]: @@ -248,6 +249,7 @@ class _StrategyType(LightningEnum): DDP_SHARDED_SPAWN = "ddp_sharded_spawn" DDP_FULLY_SHARDED = "ddp_fully_sharded" BAGUA = "bagua" + HPU_PARALLEL = "hpu_parallel" @staticmethod def interactive_compatible_types() -> list[_StrategyType]: @@ -279,6 +281,7 @@ class _AcceleratorType(LightningEnum): GPU = "GPU" IPU = "IPU" TPU = "TPU" + HPU = "HPU" class _FaultTolerantMode(LightningEnum): diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8795fa66d5fd3..c98874bcbd742 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -94,6 +94,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0") _TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1") _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") +_TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2") _TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0") _APEX_AVAILABLE = _module_available("apex.amp") @@ -112,6 +113,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") _OMEGACONF_AVAILABLE = _package_available("omegaconf") _POPTORCH_AVAILABLE = _package_available("poptorch") +_HABANA_FRAMEWORK_AVAILABLE = _package_available("habana_frameworks") _RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2") _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) _TORCHTEXT_AVAILABLE = _package_available("torchtext") @@ -134,6 +136,13 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: else: _IPU_AVAILABLE = False +if _HABANA_FRAMEWORK_AVAILABLE: + from habana_frameworks.torch.utils.library_loader import is_habana_avaialble + + _HPU_AVAILABLE = is_habana_avaialble() +else: + _HPU_AVAILABLE = False + # experimental feature within PyTorch Lightning. 
 def _fault_tolerant_training() -> bool:
diff --git a/requirements/test.txt b/requirements/test.txt
index c7d2860ebee61..51d9ecf71db44 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -6,6 +6,9 @@
 twine==3.2
 mypy>=0.920
 flake8>=3.9.2
 pre-commit>=1.0
+pytest-forked
+sklearn
+jsonargparse
 # needed in tests
 cloudpickle>=1.3
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 794cb9b2922cd..8c29dfb09a0d0 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -925,7 +925,10 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
 @mock.patch("torch.cuda.is_available", return_value=False)
 @mock.patch("pytorch_lightning.utilities.imports._TPU_AVAILABLE", return_value=False)
 @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False)
-def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock):
+@mock.patch("pytorch_lightning.utilities.imports._HPU_AVAILABLE", return_value=False)
+def test_devices_auto_choice_cpu(
+    is_ipu_available_mock, is_tpu_available_mock, is_gpu_available_mock, is_hpu_available_mock
+):
     trainer = Trainer(accelerator="auto", devices="auto")
     assert trainer.num_devices == 1
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index b783f63b1e64a..4e2b521873408 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -63,4 +63,4 @@ def is_available():


 def test_available_accelerators_in_registry():
-    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "ipu", "tpu"]
+    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "tpu"]
diff --git a/tests/accelerators/test_hpu.py b/tests/accelerators/test_hpu.py
new file mode 100644
index 0000000000000..64415eeb5a0c8
--- /dev/null
+++ b/tests/accelerators/test_hpu.py
@@ -0,0 +1,242 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+
+from pytorch_lightning import Callback, seed_everything, Trainer
+from pytorch_lightning.accelerators import HPUAccelerator
+from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy
+from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy
+from pytorch_lightning.utilities import _HPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.datamodules import ClassifDataModule
+from tests.helpers.runif import RunIf
+from tests.helpers.simple_models import ClassificationModel
+
+
+@RunIf(hpu=True)
+def test_availability():
+    assert HPUAccelerator.is_available()
+
+
+@pytest.mark.skipif(_HPU_AVAILABLE, reason="test requires non-HPU machine")
+def test_fail_if_no_hpus():
+    with pytest.raises(MisconfigurationException, match="HPUAccelerator can not run on your system"):
+        Trainer(accelerator="hpu", devices=1)
+
+
+@RunIf(hpu=True)
+def test_accelerator_selected():
+    trainer = Trainer(accelerator="hpu")
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_all_stages(tmpdir, hpus):
+    """Tests all the model stages using BoringModel on HPU."""
+    model = BoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="hpu",
+        devices=hpus,
+        precision=16,
+    )
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
+
+
+@RunIf(hpu=True)
+def test_optimization(tmpdir):
+    seed_everything(42)
+
+    dm = ClassifDataModule(length=1024)
+    model = ClassificationModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="hpu", devices=1)
+
+    # fit model
+    trainer.fit(model, dm)
+    assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert dm.trainer is not None
+
+    # validate
+    result = trainer.validate(datamodule=dm)
+    assert dm.trainer is not None
+    assert result[0]["val_acc"] > 0.7
+
+    # test
+    result = trainer.test(model, datamodule=dm)
+    assert dm.trainer is not None
+    test_result = result[0]["test_acc"]
+    assert test_result > 0.6
+
+    # test saved model
+    model_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(model_path)
+
+    model = ClassificationModel.load_from_checkpoint(model_path)
+
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="hpu", devices=1)
+
+    result = trainer.test(model, datamodule=dm)
+    saved_result = result[0]["test_acc"]
+    assert saved_result == test_result
+
+
+@RunIf(hpu=True)
+def test_stages_correct(tmpdir):
+    """Ensure all stages are traced correctly by asserting the output for each stage."""
+
+    class StageModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            loss = loss.get("loss")
+            # tracing requires a loss value that depends on the model.
+            # force it to be a value but ensure we use the loss.
+            loss = (loss - loss) + torch.tensor(1)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            loss = super().validation_step(batch, batch_idx)
+            x = loss.get("x")
+            x = (x - x) + torch.tensor(2)
+            return {"x": x}
+
+        def test_step(self, batch, batch_idx):
+            loss = super().test_step(batch, batch_idx)
+            y = loss.get("y")
+            y = (y - y) + torch.tensor(3)
+            return {"y": y}
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            output = super().predict_step(batch, batch_idx)
+            return (output - output) + torch.tensor(4)
+
+    class TestCallback(Callback):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
+            assert outputs["loss"].item() == 1
+
+        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+            assert outputs["x"].item() == 2
+
+        def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+            assert outputs["y"].item() == 3
+
+        def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+            assert torch.all(outputs == 4).item()
+
+    model = StageModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, accelerator="hpu", devices=1, callbacks=TestCallback()
+    )
+    trainer.fit(model)
+    trainer.test(model)
+    trainer.validate(model)
+    trainer.predict(model)
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu():
+
+    trainer = Trainer(accelerator="hpu", devices=1)
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+    assert trainer.num_devices == 1
+
+    trainer = Trainer(accelerator="hpu")
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+    assert trainer.num_devices == 8
+
+    trainer = Trainer(accelerator="auto", devices=8)
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+    assert trainer.num_devices == 8
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu_with_single_device():
+
+    trainer = Trainer(accelerator="hpu", devices=1)
+
+    assert isinstance(trainer.strategy, SingleHPUStrategy)
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_accelerator_hpu_with_multiple_devices():
+
+    trainer = Trainer(accelerator="hpu", devices=8)
+
+    assert isinstance(trainer.strategy, HPUParallelStrategy)
+    assert isinstance(trainer.accelerator, HPUAccelerator)
+
+
+@RunIf(hpu=True)
+def test_accelerator_auto_with_devices_hpu():
+
+    trainer = Trainer(accelerator="auto", devices=8)
+
+    assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+
+@RunIf(hpu=True)
+def test_strategy_choice_hpu_plugin():
+    trainer = Trainer(strategy=SingleHPUStrategy(device=torch.device("hpu")), accelerator="hpu", devices=1)
+    assert isinstance(trainer.strategy, SingleHPUStrategy)
+
+    trainer = Trainer(accelerator="hpu", devices=1)
+    assert isinstance(trainer.strategy, SingleHPUStrategy)
+
+
+@RunIf(hpu=True)
+def test_strategy_choice_hpu_parallel_plugin():
+    trainer = Trainer(
+        strategy=HPUParallelStrategy(parallel_devices=[torch.device("hpu")] * 8), accelerator="hpu", devices=8
+    )
+    assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+    trainer = Trainer(accelerator="hpu", devices=8)
+    assert isinstance(trainer.strategy, HPUParallelStrategy)
+
+
+@RunIf(hpu=True)
+def test_devices_auto_choice_hpu():
+    trainer = Trainer(accelerator="auto", devices="auto")
+    assert trainer.num_devices == 8
+
+
+@RunIf(hpu=True)
+@pytest.mark.parametrize("hpus", [1])
+def test_inference_only(tmpdir, hpus):
+    model = BoringModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="hpu", devices=hpus)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
+
+
+def test_hpu_auto_device_count():
+    assert HPUAccelerator.auto_device_count() == 8
+
+
+@RunIf(hpu=True)
+def test_hpu_unsupported_device_type():
+    with pytest.raises(MisconfigurationException, match="`devices` for `HPUAccelerator` must be int, string or None."):
+        Trainer(accelerator="hpu", devices=[1])
diff --git a/tests/conftest.py b/tests/conftest.py
index cdb1bfbd39392..b24f1dad8c61b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -198,3 +198,19 @@ def pytest_collection_modifyitems(items):
         # has `@RunIf(ipu=True)`
         if marker.name == "skipif" and marker.kwargs.get("ipu")
     ]
+
+
+def pytest_addoption(parser):
+    parser.addoption("--hpus", action="store", type=int, default=1, help="Number of hpus 1-8")
+    parser.addoption(
+        "--hmp-bf16", action="store", type=str, default="./ops_bf16_mnist.txt", help="bf16 ops list file in hmp O1 mode"
+    )
+    parser.addoption(
+        "--hmp-fp32", action="store", type=str, default="./ops_fp32_mnist.txt", help="fp32 ops list file in hmp O1 mode"
+    )
+
+
+@pytest.fixture
+def hpus(request):
+    hpus = request.config.getoption("--hpus")
+    return hpus
diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py
index f52aada92ea30..addc8aec16450 100644
--- a/tests/helpers/runif.py
+++ b/tests/helpers/runif.py
@@ -27,6 +27,7 @@
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _HOROVOD_AVAILABLE,
+    _HPU_AVAILABLE,
     _IPU_AVAILABLE,
     _OMEGACONF_AVAILABLE,
     _RICH_AVAILABLE,
@@ -68,6 +69,7 @@ def __new__(
         amp_apex: bool = False,
         tpu: bool = False,
         ipu: bool = False,
+        hpu: bool = False,
         horovod: bool = False,
         horovod_nccl: bool = False,
         skip_windows: bool = False,
@@ -94,6 +96,7 @@ def __new__(
             amp_apex: Require that NVIDIA/apex is installed.
             tpu: Require that TPU is available.
            ipu: Require that IPU is available.
+            hpu: Require that HPU is available.
            horovod: Require that Horovod is installed.
            horovod_nccl: Require that Horovod is installed with NCCL support.
            skip_windows: Skip for Windows platform.
@@ -154,6 +157,10 @@ def __new__(
             reasons.append("IPU")
             kwargs["ipu"] = True

+        if hpu:
+            conditions.append(not _HPU_AVAILABLE)
+            reasons.append("HPU")
+
         if horovod:
             conditions.append(not _HOROVOD_AVAILABLE)
             reasons.append("Horovod")
diff --git a/tests/plugins/precision/hpu/ops_bf16.txt b/tests/plugins/precision/hpu/ops_bf16.txt
new file mode 100644
index 0000000000000..53ec99c15b4ce
--- /dev/null
+++ b/tests/plugins/precision/hpu/ops_bf16.txt
@@ -0,0 +1,2 @@
+linear
+relu
diff --git a/tests/plugins/precision/hpu/ops_fp32.txt b/tests/plugins/precision/hpu/ops_fp32.txt
new file mode 100644
index 0000000000000..4509b7e58ac29
--- /dev/null
+++ b/tests/plugins/precision/hpu/ops_fp32.txt
@@ -0,0 +1 @@
+cross_entropy
diff --git a/tests/plugins/precision/hpu/test_hpu.py b/tests/plugins/precision/hpu/test_hpu.py
new file mode 100644
index 0000000000000..5701bf2dc2caa
--- /dev/null
+++ b/tests/plugins/precision/hpu/test_hpu.py
@@ -0,0 +1,96 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import pytest
+import torch
+
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.plugins import HPUPrecisionPlugin
+from pytorch_lightning.strategies.single_hpu import SingleHPUStrategy
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.runif import RunIf
+
+
+@pytest.fixture
+def hmp_params(request):
+    return {
+        "opt_level": "O1",
+        "verbose": False,
+        "bf16_file_path": request.config.getoption("--hmp-bf16"),
+        "fp32_file_path": request.config.getoption("--hmp-fp32"),
+    }
+
+
+@RunIf(hpu=True)
+def test_precision_plugin(hmp_params):
+    plugin = HPUPrecisionPlugin(precision="bf16", **hmp_params)
+    assert plugin.precision == "bf16"
+
+
+@RunIf(hpu=True)
+def test_mixed_precision(tmpdir, hmp_params: dict):
+    class TestCallback(Callback):
+        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+            assert trainer.strategy.model.precision == "bf16"
+            raise SystemExit
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="hpu",
+        devices=1,
+        plugins=[HPUPrecisionPlugin(precision="bf16", **hmp_params)],
+        callbacks=TestCallback(),
+    )
+    assert isinstance(trainer.strategy, SingleHPUStrategy)
+    assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin)
+    assert trainer.strategy.precision_plugin.precision == "bf16"
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@RunIf(hpu=True)
+def test_pure_half_precision(tmpdir, hmp_params: dict):
+    class TestCallback(Callback):
+        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert trainer.strategy.model.precision == 16
+            for param in trainer.strategy.model.parameters():
+                assert param.dtype == torch.float16
+            raise SystemExit
+
+    model = BoringModel()
+    model = model.half()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="hpu",
+        devices=1,
+        plugins=[HPUPrecisionPlugin(precision=16, **hmp_params)],
+        callbacks=TestCallback(),
+    )
+
+    assert isinstance(trainer.strategy, SingleHPUStrategy)
+    assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin)
+    assert trainer.strategy.precision_plugin.precision == 16
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@RunIf(hpu=True)
+def test_unsupported_precision_plugin():
+    with pytest.raises(ValueError, match=r"accelerator='hpu', precision='mixed'\)` is not supported."):
+        HPUPrecisionPlugin(precision="mixed")
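For reference, everything the new tests assert is reachable through the public Trainer API added in this change. The snippet below is a minimal usage sketch, not part of the patch: it assumes a Gaudi (DL1) host with the habana_frameworks package installed, and the two HMP ops-list paths are only illustrative placeholders taken from the test fixtures above.

# Minimal usage sketch (assumes an HPU/Gaudi machine with habana_frameworks available).
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import HPUPrecisionPlugin
from tests.helpers.boring_model import BoringModel

model = BoringModel()

# Single-card run with bf16 mixed precision driven by Habana Mixed Precision (HMP).
# The ops-list paths are placeholders; any bf16/fp32 op lists in HMP format should work.
trainer = Trainer(
    accelerator="hpu",
    devices=1,
    plugins=[
        HPUPrecisionPlugin(
            precision="bf16",
            opt_level="O1",
            verbose=False,
            bf16_file_path="tests/plugins/precision/hpu/ops_bf16.txt",
            fp32_file_path="tests/plugins/precision/hpu/ops_fp32.txt",
        )
    ],
    fast_dev_run=True,
)
trainer.fit(model)

# Eight-card run; per the tests above, the accelerator connector selects HPUParallelStrategy.
trainer = Trainer(accelerator="hpu", devices=8, fast_dev_run=True)
trainer.fit(model)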