Skip to content

Commit 91052dc

Browse files
authored
Move ipu precision flag check to IPUPrecisionPlugin init (#12148)
1 parent b5fe056 commit 91052dc

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

pytorch_lightning/plugins/precision/ipu.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,20 @@
2727

2828

2929
class IPUPrecisionPlugin(PrecisionPlugin):
30-
"""Precision plugin for IPU integration."""
30+
"""Precision plugin for IPU integration.
31+
32+
Raises:
33+
ValueError:
34+
If the precision is neither 16 nor 32.
35+
"""
3136

3237
def __init__(self, precision: int) -> None:
38+
supported_precision_values = (16, 32)
39+
if precision not in supported_precision_values:
40+
raise ValueError(
41+
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
42+
f" `precision` must be one of: {supported_precision_values}."
43+
)
3344
super().__init__()
3445
self.precision = precision
3546

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -685,12 +685,6 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
685685

686686
def _validate_precision_choice(self) -> None:
687687
"""Validate the combination of choices for precision, AMP type, and accelerator."""
688-
# TODO: change exception type to ImpactableConfigurationException
689-
if isinstance(self.accelerator, IPUAccelerator):
690-
if self._precision_flag not in (16, 32):
691-
raise MisconfigurationException(
692-
f"`Trainer(accelerator='ipu', precision={self._precision_flag!r})` is not supported."
693-
)
694688
if isinstance(self.accelerator, TPUAccelerator):
695689
if self._precision_flag == 64:
696690
raise MisconfigurationException(

tests/accelerators/test_accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -929,9 +929,9 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):
929929

930930
monkeypatch.setattr(imports, "_IPU_AVAILABLE", True)
931931
monkeypatch.setattr(ipu, "_IPU_AVAILABLE", True)
932-
with pytest.raises(MisconfigurationException, match=r"accelerator='ipu', precision='bf16'\)` is not supported"):
932+
with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16'\)` is not supported"):
933933
Trainer(accelerator="ipu", precision="bf16")
934-
with pytest.raises(MisconfigurationException, match=r"accelerator='ipu', precision=64\)` is not supported"):
934+
with pytest.raises(ValueError, match=r"accelerator='ipu', precision=64\)` is not supported"):
935935
Trainer(accelerator="ipu", precision=64)
936936

937937

0 commit comments

Comments
 (0)