[WIP] Add support for HPU accelerator #10404
@@ -0,0 +1,51 @@
import os
import sys

import habana_frameworks.torch.core as htcore
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# TBD: import these keys from hmp
hmp_keys = ["level", "verbose", "bf16_ops", "fp32_ops"]
hmp_params = dict.fromkeys(hmp_keys)
hmp_params["level"] = "O1"
hmp_params["verbose"] = False
hmp_params["bf16_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt"
hmp_params["fp32_ops"] = "./pytorch-lightning-fork/pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt"

# Initialize a trainer
trainer = pl.Trainer(hpus=1, max_epochs=1, precision=16, hmp_params=hmp_params)
We won't support adding a new flag within the Trainer such as hmp_params.

@tchaton Do you have a recommendation? It's nice to generalize this as a common param based on the backend, but that would involve modifying amp.

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
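
Following the review discussion above, a hedged sketch of how the same HMP settings could travel through the precision plugin rather than a dedicated Trainer flag. It assumes the Trainer's existing plugins argument accepts the HPUPrecisionPlugin added in this PR; this is a sketch, not the PR's final API.

from pytorch_lightning.plugins.precision.hpu_precision import HPUPrecisionPlugin

# Sketch only: hand the HMP settings to the precision plugin instead of a new
# Trainer flag (assumes `plugins=` accepts the precision plugin directly).
trainer = pl.Trainer(
    hpus=1,
    max_epochs=1,
    plugins=[HPUPrecisionPlugin(precision=16, hmp_params=hmp_params)],
)
trainer.fit(mnist_model, train_loader)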
@@ -0,0 +1,2 @@
linear
relu
@@ -0,0 +1 @@
cross_entropy
@@ -0,0 +1,50 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.precision.hpu_precision import HPUPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.hpu import HPUPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)
Do you use this? :)

Will clean it up.

class HPUAccelerator(Accelerator):
    """Accelerator for HPU devices."""

    def setup(self, trainer: "pl.Trainer") -> None:
        """
        Raises:
            ValueError:
                If the precision or training type plugin are unsupported.
        """
        if not isinstance(self.precision_plugin, HPUPrecisionPlugin):
            # this configuration should have been avoided in the accelerator connector
            raise ValueError(
                f"The `HPUAccelerator` can only be used with a `HPUPrecisionPlugin`, found: {self.precision_plugin}."
            )
        if not isinstance(self.training_type_plugin, (HPUPlugin, DDPPlugin)):
            raise ValueError(
                "The `HPUAccelerator` can only be used with a `HPUPlugin` or `DDPPlugin`,"
                f" found {self.training_type_plugin}."
            )
        return super().setup(trainer)
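
For context, a minimal sketch of the pairing that setup() validates. The keyword names of the base Accelerator constructor are an assumption here, and in practice the accelerator connector builds these objects itself.

from pytorch_lightning.plugins.precision.hpu_precision import HPUPrecisionPlugin
from pytorch_lightning.plugins.training_type.hpu import HPUPlugin

# Assumed base-class constructor arguments (precision_plugin, training_type_plugin).
accelerator = HPUAccelerator(
    precision_plugin=HPUPrecisionPlugin(precision=16, hmp_params=None),
    training_type_plugin=HPUPlugin(device=0),
)
# Any other precision plugin, or a training type plugin that is neither HPUPlugin
# nor DDPPlugin, makes setup() raise the ValueError shown above.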
@@ -251,6 +251,14 @@ def on_gpu(self):
""" | ||
return self.device.type == "cuda" | ||
|
||
@property | ||
def on_hpu(self): | ||
"""True if your model is currently running on HPUs. | ||
|
||
Useful to set flags around the LightningModule for different CPU vs GPU vs HPU behavior. | ||
""" | ||
return self.device.type == "hpu" | ||
|
||
@property | ||
def automatic_optimization(self) -> bool: | ||
"""If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``.""" | ||
|
@@ -1586,6 +1594,7 @@ def optimizer_step(
        optimizer_idx: int = 0,
        optimizer_closure: Optional[Callable[[], Any]] = None,
        on_tpu: bool = False,
        on_hpu: bool = None,
This is an API BC. Let's not add it there, and I believe it is not necessary either; the LightningOptimizer should handle the optimizer.step properly.
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:

@@ -1604,6 +1613,7 @@ def optimizer_step(
            optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the
                calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
            on_tpu: ``True`` if TPU backward is required
            on_hpu: ``True`` if HPU backward is required
            using_native_amp: ``True`` if using native amp
            using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS`
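
To illustrate the reviewer's point, a hedged sketch of user code that branches on the device inside an optimizer_step override instead of relying on a new on_hpu argument. The model class is hypothetical and not part of this PR.

import pytorch_lightning as pl


class MyHPUModel(pl.LightningModule):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx=0, optimizer_closure=None, **kwargs):
        # Branch on the module's device rather than a dedicated `on_hpu` flag.
        if self.device.type == "hpu":
            # HPU-specific handling would go here.
            pass
        optimizer.step(closure=optimizer_closure)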
@@ -3,6 +3,7 @@

import torch

from pytorch_lightning.utilities import _HPU_AVAILABLE
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8

log = logging.getLogger(__name__)
@@ -53,13 +54,21 @@ def _broadcast_object_list(object_list, src=0, group=None):

    group_backend = get_backend(group)
    is_nccl_backend = group_backend == Backend.NCCL
    import os
This os import should be at the top of the file.

    dist_backend = os.environ.get("PL_TORCH_DISTRIBUTED_BACKEND")
    is_hcl_backend = group_backend == torch.distributed.Backend(str(dist_backend))
Comment on lines +57 to +60:

All the code in this file will be removed in #10390 together with the support for PyTorch 1.6.
    current_device = torch.device("cpu")
    if is_nccl_backend:
        # See note about using torch.cuda.current_device() here in docstring.
        # We cannot simply use my_rank since rank == device is not necessarily
        # true.
        current_device = torch.device("cuda", torch.cuda.current_device())
        object_sizes_tensor = object_sizes_tensor.to(current_device)
    elif is_hcl_backend:
        current_device = torch.device("hpu")
        # Workaround: HPU does not support long tensors for collectives
        object_sizes_tensor = object_sizes_tensor.int()
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
@@ -73,6 +82,8 @@ def _broadcast_object_list(object_list, src=0, group=None):

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)
    elif is_hcl_backend:
        object_tensor = object_tensor.to(current_device)
Comment on lines 83 to +86:

Suggested change.

    broadcast(object_tensor, src=src, group=group)

@@ -93,7 +104,7 @@ def _broadcast_noop(obj, *_, **__):
        return obj

    broadcast_object_list = _broadcast_noop
elif _TORCH_GREATER_EQUAL_1_8:
elif _TORCH_GREATER_EQUAL_1_8 and not _HPU_AVAILABLE:
Mind confirming the implementation for this?

    from torch.distributed.distributed_c10d import broadcast_object_list
else:
    broadcast_object_list = _broadcast_object_list
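
Since both branches on lines 83 to +86 move the tensor to current_device, the suggested change above presumably merges them; a plausible form of that fragment:

    if is_nccl_backend or is_hcl_backend:
        object_tensor = object_tensor.to(current_device)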
@@ -0,0 +1,52 @@
# Copyright (C) 2021 Habana Labs, Ltd. an Intel Company
# All Rights Reserved.
#
# Unauthorized copying of this file or any element(s) within it, via any medium
# is strictly prohibited.
# This file contains Habana Labs, Ltd. proprietary and confidential information
# and is subject to the confidentiality and license agreements under which it
# was provided.
#

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Tuple

import torch.nn as nn
from habana_frameworks.torch.hpex import hmp
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class HPUPrecisionPlugin(PrecisionPlugin):
    """Plugin that enables bfloats/floats on HPUs."""

    def __init__(self, precision: int, hmp_params: []) -> None:
Suggested change: no mutable defaults. Also, here it is annotated as an empty list, while below you check for None.

        super().__init__()
        self.precision = precision
        if hmp_params is not None:
            hmp_opt_level = hmp_params["level"]
            hmp_bf16 = hmp_params["bf16_ops"]
Personal preference: I find this API slightly counter-intuitive. Wouldn't it be better to explode the hmp_params?

There are plans to enhance the precision params, so it would be easier to maintain it this way and avoid param explosion.

            hmp_fp32 = hmp_params["fp32_ops"]
            hmp_verbose = hmp_params["verbose"]
            hmp.convert(
                opt_level=hmp_opt_level, bf16_file_path=hmp_bf16, fp32_file_path=hmp_fp32, isVerbose=hmp_verbose
            )

    def connect(
        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
Comment on lines +49 to +52:

No need to override this if it just calls super :)

Suggested change.
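
A hedged sketch that folds in the review suggestions above (an optional mapping instead of a mutable default, and no connect override); the HMP handling otherwise mirrors the diff.

from typing import Any, Mapping, Optional

from habana_frameworks.torch.hpex import hmp

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class HPUPrecisionPlugin(PrecisionPlugin):
    """Sketch: bf16/fp32 mixed precision on HPUs via HMP."""

    def __init__(self, precision: int, hmp_params: Optional[Mapping[str, Any]] = None) -> None:
        super().__init__()
        self.precision = precision
        if hmp_params is not None:
            # Convert ops according to the user-provided bf16/fp32 op lists.
            hmp.convert(
                opt_level=hmp_params["level"],
                bf16_file_path=hmp_params["bf16_ops"],
                fp32_file_path=hmp_params["fp32_ops"],
                isVerbose=hmp_params["verbose"],
            )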
@@ -374,7 +374,7 @@ def configure_ddp(self) -> None:
        self._register_ddp_hooks()

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
        if self.root_device.type == "cpu" or self.root_device.type == "hpu":
            return None
        return [self.root_device.index]

@@ -534,6 +534,14 @@ def reconciliate_processes(self, trace: str) -> None:
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
on_save has been removed and won't be used anymore.
        if self.root_device.type == "hpu" and self.cluster_environment.global_rank() == 0:
            from pytorch_lightning.utilities.apply_func import move_data_to_device

            return move_data_to_device(checkpoint, torch.device("cpu"))
        else:
            return checkpoint
Comment on lines +537 to +544:

Suggested change.

    def teardown(self) -> None:
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module
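
A small aside on determine_ddp_device_ids above: an equivalent, slightly tighter form of the added check would be the following sketch.

    def determine_ddp_device_ids(self):
        # Return no explicit device ids for device types without per-process indices.
        if self.root_device.type in ("cpu", "hpu"):
            return None
        return [self.root_device.index]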
@@ -0,0 +1,76 @@
# Copyright (C) 2021 Habana Labs, Ltd. an Intel Company
# All Rights Reserved.
#
# Unauthorized copying of this file or any element(s) within it, via any medium
# is strictly prohibited.
# This file contains Habana Labs, Ltd. proprietary and confidential information
# and is subject to the confidentiality and license agreements under which it
# was provided.
#

import os
from typing import Any, Dict, Optional

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _HPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _PATH


class HPUPlugin(SingleDevicePlugin):
    def __init__(
        self,
        device: int,
        checkpoint_io: Optional[CheckpointIO] = None,
        debug: bool = False,
    ):
Docstring?

        device = torch.device("hpu")
We can pass them directly to super().__init__().

        checkpoint_io = checkpoint_io
Suggested change.

        super().__init__(device, checkpoint_io=checkpoint_io)

        self.debug = debug
What is this flag needed for?

Not needed. Will remove it.

    @property
    def is_distributed(self) -> bool:
        return False
Comment on lines +52 to +54:

Isn't this always the case? If not, it is likely an issue beyond this PR; if yes, there is no need to override :)

I was wondering whether the same plugin could do 1x and 8x and be generalized, so I added it.

    def setup(self) -> None:
        shared_params = find_shared_parameters(self.model)
I don't believe this logic is required for HPU. This is quite specific to TPUs, which don't support tying.

This is common for all sharding use cases, if I am not wrong.
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)

    def model_to_device(self) -> None:
        self.model.to(self.root_device)

    @property
    def on_hpu(self) -> bool:
        return True

    def pre_dispatch(self) -> None:
        if isinstance(self.device, int):
            self.device = torch.device(self.device)

    def on_save(self, checkpoint: dict) -> dict:
        return move_data_to_device(checkpoint, torch.device("cpu"))
I believe this should use the checkpoint_io plugin.
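
A hedged sketch of that suggestion: delegate the CPU move to a custom CheckpointIO instead of on_save. The class name and exact signature are illustrative and not part of this PR.

from typing import Any, Dict

import torch

from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.types import _PATH


class HPUCheckpointIO(TorchCheckpointIO):
    """Illustrative CheckpointIO that moves tensors to CPU before saving."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Any = None) -> None:
        checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
        super().save_checkpoint(checkpoint, path, storage_options=storage_options)

The plugin could then receive it through its existing checkpoint_io argument.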
Would it be possible to not have them in this directory, but instead dynamically write those files in the example? Then users could simply copy-paste the example code, and it would also rely less on relative paths :)
Can we keep them inside this repo? In fact, they are here, but the path is not updated...
My thought was to showcase a simple example of the usage. It would be nice to keep it along with this working file.
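
A hedged sketch of the dynamic-write idea discussed above, reusing the op lists shown earlier in this diff; paths and file names here are illustrative.

import os
import tempfile

# Contents mirror ops_bf16_mnist.txt and ops_fp32_mnist.txt from this PR.
bf16_ops = ["linear", "relu"]
fp32_ops = ["cross_entropy"]

tmp_dir = tempfile.mkdtemp()
bf16_path = os.path.join(tmp_dir, "ops_bf16_mnist.txt")
fp32_path = os.path.join(tmp_dir, "ops_fp32_mnist.txt")

with open(bf16_path, "w") as f:
    f.write("\n".join(bf16_ops) + "\n")
with open(fp32_path, "w") as f:
    f.write("\n".join(fp32_ops) + "\n")

hmp_params = {"level": "O1", "verbose": False, "bf16_ops": bf16_path, "fp32_ops": fp32_path}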