Add support for Habana accelerator (HPU) #11808

Merged: 183 commits, Mar 25, 2022

Changes from 17 commits

Commits (183)
f7175c4
Add hpu accelerator support
jerome-habana Feb 8, 2022
7fb871b
Update strategy for optimizer usage
jerome-habana Feb 8, 2022
a1a1ca9
Add checkpointing support
jerome-habana Feb 8, 2022
9a6da43
Fix distributed support with hpu
jerome-habana Feb 8, 2022
3e76db9
Enable usage of static_graph with hpu
jerome-habana Feb 8, 2022
b43d226
Add HPU tests
jerome-habana Feb 8, 2022
992093d
Add basic hpu_stats monitor
jerome-habana Feb 8, 2022
943be49
Code cleanup
jerome-habana Feb 8, 2022
3015972
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
257d644
Update tests
jerome-habana Feb 9, 2022
f1867cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2022
c61d68b
Add configurable params for tests
jerome-habana Feb 10, 2022
f74a898
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2022
963cd1e
Enable inference test
jerome-habana Feb 11, 2022
53a5416
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2022
2de04e8
Resolve issue with hmp params type and load hpu
jerome-habana Feb 15, 2022
0197b9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2022
b412638
Move hmp_params to HPUPrecision plugin
jerome-habana Feb 17, 2022
e549434
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
1cc0a37
Update habana distributed with ddp subclass
jerome-habana Feb 18, 2022
aeda681
Add hpu backend, datatype checks
jerome-habana Feb 18, 2022
fe32865
Merge branch 'master' into hpu_accelerator
jerome-habana Feb 23, 2022
f9b0c5f
Merge branch 'master' into hpu_accelerator
jerome-habana Feb 23, 2022
123112d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2022
ede68eb
Remove unused param for 'on_train_batch_end' in hpu test
jerome-habana Feb 23, 2022
262343a
Merge branch 'master' into hpu_accelerator
jerome-habana Mar 3, 2022
3a029c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
0a959f0
Address review comments
jerome-habana Mar 3, 2022
1434299
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
400ea77
Address review comments
jerome-habana Mar 4, 2022
4146bab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2022
f5cb696
remove deprecated logging
jerome-habana Mar 4, 2022
d3cd6b1
Merge branch 'master' into hpu_accelerator
jerome-habana Mar 7, 2022
448ed77
Fix imports for failing CI
kaushikb11 Mar 9, 2022
10b190f
fix str to_device section in converting.rst (#12243)
awaelchli Mar 7, 2022
c17c62b
Disable tuner with distributed strategies (#12179)
rohitgr7 Mar 7, 2022
28bc4f0
Add callout items to the Docs landing page (#12196)
kaushikb11 Mar 7, 2022
97e1d28
Integrate global step with progress tracking (#11805)
carmocca Mar 7, 2022
5aecf65
Deprecate `LightningDataModule.on_save/load_checkpoint` (#11893)
jjenniferdai Mar 8, 2022
0949599
add Azure HPU agent (#12258)
Borda Mar 8, 2022
4bd5034
Add `LightningCLI(auto_registry)` (#12108)
carmocca Mar 8, 2022
bd76456
Drop PyTorch 1.7 testing from the CI (#12191)
krshrimali Mar 8, 2022
80b8d01
Have the outputs match the loops format (#12182)
carmocca Mar 8, 2022
c168db5
Address review comments
jerome-habana Mar 9, 2022
831a672
Review comment: Make use of Boring model
jerome-habana Mar 9, 2022
328329e
Update stats example trainer params
jerome-habana Mar 9, 2022
c8e331e
Correct flake8 errors
jerome-habana Mar 9, 2022
9a71bdc
Remove docstring examples
jerome-habana Mar 9, 2022
8efed0b
Update hpu-tests.yml
raoakarsha Mar 3, 2022
90409a2
prune
Borda Mar 7, 2022
5bbc6dc
Update hpu-tests.yml
Borda Mar 8, 2022
85f535b
Apply suggestions from code review
Borda Mar 9, 2022
75227d9
hwinfo
Borda Mar 9, 2022
711bbf3
Override mypy warnings
jerome-habana Mar 10, 2022
bc174f6
Update test and requirements file
jerome-habana Mar 10, 2022
b28c0ce
Remove hpu stats monitor and deprecated APIs
jerome-habana Mar 10, 2022
3c08bf5
Update non-hpu tests
jerome-habana Mar 10, 2022
f857721
Add hpu-tests.yml and run_hpu_tests.py to support HPU Testing
Borda Mar 10, 2022
a2b2cb1
Merge branch 'master' into hpu_accelerator
jerome-habana Mar 10, 2022
7cb34bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
f6baf69
Add exception for non-hpu tests
jerome-habana Mar 10, 2022
21fc9a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
3665ffc
Throw exception when accelerator is not present
jerome-habana Mar 10, 2022
e0b4611
Resolve mypy and error message
jerome-habana Mar 10, 2022
545ab6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
96ed1cd
Disable hpu pl examples on CPU
jerome-habana Mar 10, 2022
c44b017
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2022
410875c
Address review comments
jerome-habana Mar 14, 2022
8efe56f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2022
073b170
Add documentation for habana gaudi accelerator (HPU)
jerome-habana Mar 15, 2022
7bdcaf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2022
da1037a
Update test code syntax
jerome-habana Mar 15, 2022
5e7af01
Mitigate duplicate label error
jerome-habana Mar 15, 2022
70d6993
Add hpu to toctree
jerome-habana Mar 16, 2022
5061d71
Update pytorch_lightning/plugins/precision/hpu_precision.py
kaushikb11 Mar 16, 2022
f6c36ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2022
798f137
Update _broadcast_object_list
kaushikb11 Mar 16, 2022
5e098cb
Update broadcast for HPUParallelStrategy
kaushikb11 Mar 16, 2022
093056c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2022
0563310
Update reference links
kaushikb11 Mar 17, 2022
65886ba
Update Strategies
kaushikb11 Mar 17, 2022
d837ef3
Address reviews
kaushikb11 Mar 17, 2022
37e0000
Address reviews
kaushikb11 Mar 17, 2022
07c60b4
Address reviews
jerome-habana Mar 18, 2022
394d9e2
Merge branch 'master' into hpu_accelerator
jerome-habana Mar 18, 2022
12dc3ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
3064544
Remove too many sections from sidebar
akihironitta Mar 19, 2022
7c7721d
Fix invalid formatting and links
akihironitta Mar 19, 2022
cc71c7a
Merge branch 'master' into hpu_accelerator
kaushikb11 Mar 21, 2022
e6eaa9f
Address reviews for HPUCheckpointIO
kaushikb11 Mar 21, 2022
33beabd
Address reviews for HPU + AcceleratorConnector
kaushikb11 Mar 21, 2022
759804e
Fix tests
kaushikb11 Mar 21, 2022
bda7e36
Address reviews
kaushikb11 Mar 21, 2022
bdc19be
Remove setting hpu accelerator by just strategy
kaushikb11 Mar 21, 2022
2d34cc5
Remove unnecessary properties for HPU
kaushikb11 Mar 21, 2022
c32601a
Fix HPU tests
kaushikb11 Mar 21, 2022
f43750e
Move tests
kaushikb11 Mar 21, 2022
4e09286
Improve docs
kaushikb11 Mar 21, 2022
ab2f595
Improve tests
kaushikb11 Mar 21, 2022
549d784
Update Changelog
kaushikb11 Mar 21, 2022
ec929df
Fix test for the right device type
kaushikb11 Mar 21, 2022
c55a82f
Fix tests
kaushikb11 Mar 21, 2022
05dcc1c
Fix tests
kaushikb11 Mar 21, 2022
150e667
Merge branch 'master' into hpu_accelerator
kaushikb11 Mar 21, 2022
f5a333b
Address reviews
kaushikb11 Mar 21, 2022
57b9c24
Update plugins
kaushikb11 Mar 21, 2022
3dd763c
Update docs/source/accelerators/hpu.rst
kaushikb11 Mar 22, 2022
773a7a0
Update HPU mnist example
kaushikb11 Mar 22, 2022
9378c87
Update strategy
kaushikb11 Mar 22, 2022
9aefcd2
Address reviews
jerome-habana Mar 22, 2022
1f0b187
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
1d30ef9
Add precision tests to azure pipeline
jerome-habana Mar 22, 2022
fd9488f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
a4f79fb
Add comments
kaushikb11 Mar 22, 2022
a6a336d
Fix argparse
kaushikb11 Mar 22, 2022
dca30ee
Remove unnecessary use of PL_TORCH_DISTRIBUTED_BACKEND env variable
kaushikb11 Mar 22, 2022
bb8984f
Update pytorch_lightning/strategies/hpu_parallel.py
kaushikb11 Mar 22, 2022
4ab35db
Update pytorch_lightning/utilities/distributed.py
kaushikb11 Mar 22, 2022
e65a3fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
a517942
Address review
jerome-habana Mar 23, 2022
d89815d
Address reviews
kaushikb11 Mar 23, 2022
0238b45
Update document
jerome-habana Mar 23, 2022
4f44ea9
Improve Habana doc
kaushikb11 Mar 23, 2022
f332e1c
Improve Habana doc
kaushikb11 Mar 23, 2022
81202c6
Improve Habana doc
kaushikb11 Mar 23, 2022
503df4e
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Mar 23, 2022
e6af417
Update links
kaushikb11 Mar 23, 2022
2bd4a66
Merge branch 'hpu_accelerator' of https://github.com/jerome-habana/py…
kaushikb11 Mar 23, 2022
67e710e
Update precision sections
kaushikb11 Mar 23, 2022
1df801b
Update doc
kaushikb11 Mar 23, 2022
9152114
Add defaults to hmp_params for Precision Plugin
kaushikb11 Mar 23, 2022
9846b6a
Update .azure-pipelines/run_hpu_tests.py
kaushikb11 Mar 24, 2022
e86becf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2022
d165c44
Apply suggestions from code review
kaushikb11 Mar 24, 2022
c76b95f
Update docs/source/accelerators/hpu.rst
kaushikb11 Mar 24, 2022
bafcb8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2022
2d6c6dd
Apply suggestions from code review
kaushikb11 Mar 24, 2022
75728b6
Apply suggestions from code review
kaushikb11 Mar 24, 2022
68c5281
Update docs/source/accelerators/hpu.rst
kaushikb11 Mar 24, 2022
600e1bd
Address reviews
kaushikb11 Mar 24, 2022
b03d079
Apply suggestions from code review
kaushikb11 Mar 24, 2022
6e4474e
Update API references
kaushikb11 Mar 24, 2022
efd9f65
Address reviews regarding precision
kaushikb11 Mar 24, 2022
22827f0
Address reviews regarding docs and precision
kaushikb11 Mar 24, 2022
e82544c
Update docs/source/accelerators/hpu.rst
kaushikb11 Mar 24, 2022
4500a7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2022
98ba21f
Apply suggestions from code review
kaushikb11 Mar 24, 2022
3c10359
Address reviews & update tests
kaushikb11 Mar 24, 2022
6c0dd88
Merge branch 'hpu_accelerator' of https://github.com/jerome-habana/py…
kaushikb11 Mar 24, 2022
e137f19
Update testing pipeline & conftest
kaushikb11 Mar 24, 2022
a62cfa1
Fix ci
kaushikb11 Mar 24, 2022
1078a69
Add device parsing logic for HPUs
kaushikb11 Mar 24, 2022
a9dfcf3
Fix device parsing
kaushikb11 Mar 24, 2022
4665101
Use the CLI in the example
Mar 24, 2022
2ee4bbf
Docs
Mar 24, 2022
e9ae312
Merge branch 'master' into hpu_accelerator
kaushikb11 Mar 24, 2022
dc3eca7
Update docs/source/accelerators/hpu.rst
kaushikb11 Mar 24, 2022
6952125
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2022
91cced3
Update hmp_params
kaushikb11 Mar 24, 2022
0671d2c
Support passing amp_level to HPUPrecision
kaushikb11 Mar 24, 2022
522106e
Update HPUAccelerator
kaushikb11 Mar 24, 2022
c8b89ea
Update tests
kaushikb11 Mar 25, 2022
7d028b1
Fix precision tests
kaushikb11 Mar 25, 2022
3c86aff
Update device parsing logic
kaushikb11 Mar 25, 2022
3c8e321
Fix tests & address reviews
kaushikb11 Mar 25, 2022
dcda0ac
Update run_hpu_tests
kaushikb11 Mar 25, 2022
e254cd0
Update CLI test
jerome-habana Mar 25, 2022
c452bd2
Fix typing
kaushikb11 Mar 25, 2022
4c51b33
Merge branch 'hpu_accelerator' of https://github.com/jerome-habana/py…
kaushikb11 Mar 25, 2022
b66c867
Merge branch 'master' into hpu_accelerator
jerome-habana Mar 25, 2022
dca6b0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2022
98e901d
Enable example test in pipeline
jerome-habana Mar 25, 2022
2860a4e
export path of modules
jerome-habana Mar 25, 2022
a297593
Fix test
kaushikb11 Mar 25, 2022
9c1fff7
Merge branch 'hpu_accelerator' of https://github.com/jerome-habana/py…
kaushikb11 Mar 25, 2022
65f1fb9
Update torch distributed
kaushikb11 Mar 25, 2022
2380887
Update strategy
kaushikb11 Mar 25, 2022
59ef6fd
Update example
kaushikb11 Mar 25, 2022
c02c1ed
Apply suggestions from code review
kaushikb11 Mar 25, 2022
beda30c
Address reviews
kaushikb11 Mar 25, 2022
eb99e52
Merge branch 'hpu_accelerator' of https://github.com/jerome-habana/py…
kaushikb11 Mar 25, 2022
c465a06
Update backend env variable for strategy
kaushikb11 Mar 25, 2022
60f2da4
Update backend env variable for strategy
kaushikb11 Mar 25, 2022
62 changes: 62 additions & 0 deletions pl_examples/hpu_examples/simple_mnist/mnist.py
@@ -0,0 +1,62 @@
import os
import sys

import habana_frameworks.torch.core as htcore
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from pytorch_lightning.callbacks import HPUStatsMonitor


class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# TBD: import these keys from hmp
hmp_keys = ["level", "verbose", "bf16_ops", "fp32_ops"]
hmp_params = dict.fromkeys(hmp_keys)
hmp_params["level"] = "O1"
hmp_params["verbose"] = False
hmp_params["bf16_ops"] = "./pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt"
hmp_params["fp32_ops"] = "./pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt"

hpu_stats = HPUStatsMonitor(log_save_dir="habana_ptl_log", exp_name="mnist")

# Initialize a trainer
trainer = pl.Trainer(
    devices=1,
    callbacks=[hpu_stats],
    max_epochs=1,
    precision=32,
    hmp_params=hmp_params,
    default_root_dir="/tmp/",
    accelerator="hpu",
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
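
For scale-out, later commits in this PR add an HPUParallelStrategy; below is a minimal sketch of running the same example on a full Gaudi server, assuming the "hpu_parallel" registration name and the 8-card count from those commits (both are assumptions, not part of this 17-commit diff):

# Sketch only: "hpu_parallel" and devices=8 are taken from the HPUParallelStrategy
# commits listed above, not from this file.
trainer = pl.Trainer(
    accelerator="hpu",
    devices=8,                 # one process per Gaudi card
    strategy="hpu_parallel",   # DDP-style strategy added later in this PR
    max_epochs=1,
    default_root_dir="/tmp/",
)
trainer.fit(mnist_model, train_loader)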
2 changes: 2 additions & 0 deletions pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt
@@ -0,0 +1,2 @@
linear
relu
1 change: 1 addition & 0 deletions pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt
@@ -0,0 +1 @@
cross_entropy
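
The two single-column files above are the op lists consumed by Habana Mixed Precision (HMP): linear and relu are cast to bf16, while cross_entropy stays in fp32. Later commits move hmp_params out of the Trainer and into an HPU precision plugin; a rough sketch of that style of configuration follows, with the class name and keyword names assumed from the "Move hmp_params to HPUPrecision plugin" commit rather than taken from this diff:

# Assumed API, for illustration only; the exact plugin signature lands in
# later commits of this PR, not in the 17-commit state shown here.
import pytorch_lightning as pl
from pytorch_lightning.plugins import HPUPrecisionPlugin

precision_plugin = HPUPrecisionPlugin(
    precision=16,
    opt_level="O1",
    bf16_file_path="./pl_examples/hpu_examples/simple_mnist/ops_bf16_mnist.txt",
    fp32_file_path="./pl_examples/hpu_examples/simple_mnist/ops_fp32_mnist.txt",
)
trainer = pl.Trainer(accelerator="hpu", devices=1, plugins=[precision_plugin])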
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -13,5 +13,6 @@
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -28,6 +28,7 @@ class Accelerator(ABC):
    - GPU
    - TPU
    - IPU
    - HPU
    """

    def setup_environment(self, root_device: torch.device) -> None:
33 changes: 33 additions & 0 deletions pytorch_lightning/accelerators/hpu.py
@@ -0,0 +1,33 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator


class HPUAccelerator(Accelerator):
"""Accelerator for HPU devices."""

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """HPU device stats aren't supported yet."""
        return {}

    @staticmethod
    def auto_device_count() -> int:
        """Get the devices when set to auto."""
        # TBD: make this configurable
        return 8
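
Since auto_device_count() reports a fixed count of 8, HPU selection can follow the same pattern as the other accelerators. A minimal sketch, assuming the "hpu" accelerator string registered by this PR:

import pytorch_lightning as pl

# devices="auto" resolves through HPUAccelerator.auto_device_count(), i.e. 8
# cards (a full Gaudi server) until the TBD above makes the count configurable.
trainer = pl.Trainer(accelerator="hpu", devices="auto")

# A single card can also be requested explicitly.
trainer_single = pl.Trainer(accelerator="hpu", devices=1)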
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -17,6 +17,7 @@
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.hpu_stats_monitor import HPUStatsMonitor
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
@@ -37,6 +38,7 @@
    "DeviceStatsMonitor",
    "EarlyStopping",
    "GPUStatsMonitor",
    "HPUStatsMonitor",
    "XLAStatsMonitor",
    "GradientAccumulationScheduler",
    "LambdaCallback",
80 changes: 80 additions & 0 deletions pytorch_lightning/callbacks/hpu_stats_monitor.py
@@ -0,0 +1,80 @@
# Copyright (C) 2021 Habana Labs, Ltd. an Intel Company
# All Rights Reserved.
#
# Unauthorized copying of this file or any element(s) within it, via any medium
# is strictly prohibited.
# This file contains Habana Labs, Ltd. proprietary and confidential information
# and is subject to the confidentiality and license agreements under which it
# was provided.
#

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HPU Stats Monitor
=================

Monitors and logs HPU stats during training.

"""
from typing import Any, Dict, List, Optional, Tuple

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only


class HPUStatsMonitor(Callback):
"""Automatically monitors and logs hpu stats during training stage.

Args:
save_dir: directory to save the logs.
exp_name: name of the experiment.

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import HPUStatsMonitor
>>> hpu_stats = HPUStatsMonitor()
>>> trainer = Trainer(hpus=1, callbacks=[hpu_stats])

you can also optionally provide save_dir and exp_name in HPUStatsMonitor.
No need to provide logger in Trainer.
"""

def __init__(self, log_save_dir: str = "habana_ptl_logs", exp_name: str = "default"):
super().__init__()
self.log_save_dir = log_save_dir
self.exp_name = exp_name

def on_init_end(self, trainer: "pl.Trainer") -> None:
from pytorch_lightning import loggers as pl_logger

self.tb_logger = pl_logger.TensorBoardLogger(save_dir=self.log_save_dir, name=self.exp_name)
trainer.logger = self.tb_logger

def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
pl_module.log("Model_Loss", loss, on_step=True, on_epoch=True, enable_graph=False, logger=True)

def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
) -> None:
tensor_board = trainer.logger.experiment
dict = vars(pl_module)
modules = dict["_modules"]
for module_name in modules:
tensor_board.add_histogram(module_name + ".weight", modules[module_name].weight, pl_module.current_epoch)
tensor_board.add_histogram(module_name + ".bias", modules[module_name].bias, pl_module.current_epoch)
8 changes: 8 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -223,6 +223,14 @@ def on_gpu(self):
        """
        return self.device.type == "cuda"

    @property
    def on_hpu(self):
        """True if your model is currently running on HPUs.

        Useful to set flags around the LightningModule for different CPU vs GPU vs HPU behavior.
        """
        return self.device.type == "hpu"

    @property
    def automatic_optimization(self) -> bool:
        """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``."""
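
The new on_hpu property mirrors on_gpu, so a LightningModule can branch on the device type without inspecting self.device directly. A small illustrative sketch (the module and the hook body are hypothetical, not part of this PR):

import torch
from torch.nn import functional as F

import pytorch_lightning as pl


class DeviceAwareModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        if self.on_hpu:  # True when self.device.type == "hpu"
            # Device-specific behaviour (logging, op selection, ...) goes here.
            self.log("running_on_hpu", 1.0)
        return loss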
2 changes: 2 additions & 0 deletions pytorch_lightning/lite/lite.py
@@ -86,6 +86,7 @@ def __init__(
            devices=devices,
            tpu_cores=tpu_cores,
            ipus=None,
            hpus=None,
            accelerator=accelerator,
            strategy=strategy,
            gpus=gpus,
@@ -98,6 +99,7 @@
            precision=precision,
            amp_type="native",
            amp_level=None,
            hmp_params=None,
            plugins=plugins,
        )
        self._strategy = self._accelerator_connector.strategy