
Commit f245ea7

kaushikb11 authored and lexierule committed

Don't import torch_xla.debug for torch-xla<1.8 (#10836)

1 parent d2e791e commit f245ea7
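The commit gates the `torch_xla.debug.profiler` import behind a torch version flag. As a minimal sketch of how such a flag can be derived, assuming the `packaging` library is available (the actual `_TORCH_GREATER_EQUAL_1_8` helper in `pytorch_lightning.utilities` may be implemented differently):

import torch
from packaging.version import Version

# Compare only the release tuple so local builds like "1.8.0+cu111"
# still count as >= 1.8. This is a sketch, not Lightning's actual helper.
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__).release >= (1, 8)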

File tree: 3 files changed, +12 -5 lines changed


CHANGELOG.md
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
 - Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 - Fixed an issue with item assignment on the logger on rank > 0 for those who support it ([#10917](https://github.com/PyTorchLightning/pytorch-lightning/pull/10917))
-
+- Fixed importing `torch_xla.debug` for `torch-xla<1.8` ([#10836](https://github.com/PyTorchLightning/pytorch-lightning/pull/10836))
 
 
 ## [1.5.4] - 2021-11-30

pytorch_lightning/profiler/xla.py
Lines changed: 7 additions & 2 deletions

@@ -42,9 +42,10 @@
 from typing import Dict
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TPU_AVAILABLE:
+if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
     import torch_xla.debug.profiler as xp
 
 log = logging.getLogger(__name__)
@@ -65,6 +66,10 @@ class XLAProfiler(BaseProfiler):
     def __init__(self, port: int = 9012) -> None:
         """This Profiler will help you debug and optimize training workload performance for your models using Cloud
         TPU performance tools."""
+        if not _TPU_AVAILABLE:
+            raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
+        if not _TORCH_GREATER_EQUAL_1_8:
+            raise MisconfigurationException("`XLAProfiler` is only supported with `torch-xla >= 1.8`")
         super().__init__(dirpath=None, filename=None)
         self.port = port
         self._recording_map: Dict = {}
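With these guards, a misconfigured environment fails fast in `__init__` with a clear message rather than hitting a missing `torch_xla.debug` import later. A usage sketch on a TPU host with torch-xla >= 1.8 (Trainer arguments follow the 1.5-era API; the model stands in for any LightningModule):

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler

# Raises MisconfigurationException immediately when no TPU is available
# or torch-xla < 1.8 is installed, instead of failing at import time.
profiler = XLAProfiler(port=9012)

trainer = Trainer(tpu_cores=8, profiler=profiler, max_epochs=1)
# trainer.fit(model)  # any LightningModule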

tests/profiler/test_xla_profiler.py
Lines changed: 4 additions & 2 deletions

@@ -18,14 +18,16 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.profiler import XLAProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 if _TPU_AVAILABLE:
-    import torch_xla.debug.profiler as xp
     import torch_xla.utils.utils as xu
 
+if _TORCH_GREATER_EQUAL_1_8:
+    import torch_xla.debug.profiler as xp
+
 
 @RunIf(tpu=True)
 def test_xla_profiler_instance(tmpdir):
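The updated test module keeps its TPU-gated tests. A hypothetical companion test (not part of this commit) could assert the new fail-fast behavior on hosts without TPUs:

import pytest

from pytorch_lightning.profiler import XLAProfiler
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@pytest.mark.skipif(_TPU_AVAILABLE, reason="requires a host without TPUs")
def test_xla_profiler_raises_without_tpu():
    # The match string mirrors the message raised in XLAProfiler.__init__.
    with pytest.raises(MisconfigurationException, match="only supported on TPUs"):
        XLAProfiler()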
