Skip to content

Commit 52221c0

Browse files
carmoccarohitgr7
authored andcommitted
Validate the precision input earlier (Lightning-AI#9763)
1 parent 12ac06b commit 52221c0

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def __init__(
133133
self.sync_batchnorm = sync_batchnorm
134134
self.benchmark = benchmark
135135
self.replace_sampler_ddp = replace_sampler_ddp
136+
if not PrecisionType.supported_type(precision):
137+
raise MisconfigurationException(
138+
f"Precision {repr(precision)} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
139+
)
136140
self.precision = precision
137141
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
138142
self.amp_level = amp_level
@@ -662,10 +666,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
662666

663667
return ApexMixedPrecisionPlugin(self.amp_level)
664668

665-
raise MisconfigurationException(
666-
f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
667-
)
668-
669669
def select_training_type_plugin(self) -> TrainingTypePlugin:
670670
if isinstance(self.accelerator, Accelerator) and self.accelerator.training_type_plugin is not None:
671671
plugin = self.accelerator.training_type_plugin

pytorch_lightning/utilities/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class PrecisionType(LightningEnum):
6262
FLOAT = "32"
6363
FULL = "64"
6464
BFLOAT = "bf16"
65+
MIXED = "mixed"
6566

6667
@staticmethod
6768
def supported_type(precision: Union[str, int]) -> bool:

tests/accelerators/test_accelerator_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
709709
@pytest.mark.parametrize("precision", [1, 12, "invalid"])
710710
def test_validate_precision_type(tmpdir, precision):
711711

712-
with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
712+
with pytest.raises(MisconfigurationException, match=f"Precision {repr(precision)} is invalid"):
713713
Trainer(precision=precision)
714714

715715

tests/utilities/test_enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_consistency():
2727

2828

2929
def test_precision_supported_types():
30-
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
30+
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"]
3131
assert PrecisionType.supported_type(16)
3232
assert PrecisionType.supported_type("16")
3333
assert not PrecisionType.supported_type(1)

0 commit comments

Comments
 (0)