
Commit 7901092

fix tpu error check
1 parent 9dc0962 commit 7901092

File tree: 2 files changed (+27 / -22 lines)


pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 19 additions & 13 deletions
@@ -177,6 +177,7 @@ def __init__(
         self._training_type_plugin_resolved = False
         self.training_type_plugin = self.final_training_type_plugin()
         self.accelerator = self.training_type_plugin.accelerator
+        self.precision_plugin = self.training_type_plugin._precision_plugin

         self._check_tpu_mis_config()

@@ -391,8 +392,7 @@ def handle_given_plugins(self) -> None:
     def accelerator_types(self) -> List[str]:
         return ["auto"] + list(_AcceleratorType)

-    @property
-    def precision_plugin(self) -> PrecisionPlugin:
+    def final_precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
             self._precision_plugin = self.select_precision_plugin()
         return self._precision_plugin
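
Editor's note on the rename above (my illustration, not part of the commit): once __init__ assigns self.precision_plugin as a plain instance attribute (see the first hunk), a read-only precision_plugin property on the same class would block that assignment, which is presumably why the resolver becomes an ordinary final_precision_plugin() method. A minimal standalone sketch of the conflict, with made-up class names:

    # Illustration only: made-up classes showing why a read-only property and an
    # instance attribute with the same name cannot coexist.
    class WithProperty:
        @property
        def precision_plugin(self):  # read-only property, no setter defined
            return "resolved precision"

    class WithMethod:
        def final_precision_plugin(self):  # plain method, leaves the name free
            return "resolved precision"

    obj = WithProperty()
    try:
        obj.precision_plugin = "value from strategy"
    except AttributeError as err:
        print("property blocks the assignment:", err)

    obj = WithMethod()
    obj.precision_plugin = obj.final_precision_plugin()  # ordinary attribute, works
    print(obj.precision_plugin)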
@@ -408,16 +408,28 @@ def final_training_type_plugin(self) -> TrainingTypePlugin:
         if self._checkpoint_io is not None:
             self._training_type_plugin.checkpoint_io = self._checkpoint_io
         if (
-            (hasattr(self.strategy, "precision_plugin") and self.precision_plugin is None)
-            or not hasattr(self.strategy, "precision_plugin")
+            # handle custom strategy with custom precision
+            (
+                isinstance(self.strategy, TrainingTypePlugin) and (
+                    self.strategy.precision_plugin is None
+                    or not isinstance(self.strategy.precision_plugin, PrecisionPlugin)
+                )
+            )
+            or not isinstance(self.strategy, TrainingTypePlugin)
         ):
-            precision_plugin = self.precision_plugin
+            precision_plugin = self.final_precision_plugin()
             if precision_plugin is not None:
                 self._training_type_plugin._precision_plugin = precision_plugin
         self._training_type_plugin_resolved = True
         if (
-            (hasattr(self.strategy, "accelerator") and self.strategy.accelerator is None)
-            or not hasattr(self.strategy, "accelerator")
+            # handle custom strategy with custom accelerator
+            (
+                isinstance(self.strategy, TrainingTypePlugin) and (
+                    self.strategy.accelerator is None
+                    or not isinstance(self.strategy.accelerator, Accelerator)
+                )
+            )
+            or not isinstance(self.strategy, TrainingTypePlugin)
         ):
             self._training_type_plugin.accelerator = self.select_accelerator()
         return self._training_type_plugin
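
For context on the new condition (an illustrative sketch with hypothetical stand-in classes, not code from the repo): switching from hasattr to isinstance means an object that merely exposes a precision_plugin or accelerator attribute no longer counts as a custom strategy; only a genuine TrainingTypePlugin carrying a genuine PrecisionPlugin/Accelerator keeps its own objects, and everything else falls back to the connector's defaults. Roughly:

    # Illustration only: stand-in classes reusing the plugin names for clarity.
    class PrecisionPlugin:
        pass

    class TrainingTypePlugin:
        def __init__(self, precision_plugin=None):
            self.precision_plugin = precision_plugin

    def needs_default_precision(strategy) -> bool:
        # Mirrors the new check: fall back to the connector's default unless the
        # strategy is a real TrainingTypePlugin already carrying a real PrecisionPlugin.
        return (
            not isinstance(strategy, TrainingTypePlugin)
            or strategy.precision_plugin is None
            or not isinstance(strategy.precision_plugin, PrecisionPlugin)
        )

    print(needs_default_precision("ddp"))                                  # True: a string flag, not a plugin
    print(needs_default_precision(TrainingTypePlugin()))                   # True: plugin without precision
    print(needs_default_precision(TrainingTypePlugin(PrecisionPlugin())))  # False: custom precision is kept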
@@ -790,12 +802,6 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
     def select_accelerator(self) -> Accelerator:
         if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
-            if self._precision_plugin is not None or self._training_type_plugin is not None:
-                # plugins also specified by user
-                rank_zero_warn(
-                    "Specified `Precision` and `TrainingType` plugins will be ignored,"
-                    " since an `Accelerator` instance was provided."
-                )
             return self.distributed_backend

         if self.use_gpu:

tests/accelerators/test_accelerator_connector.py

Lines changed: 8 additions & 9 deletions
@@ -418,11 +418,7 @@ class Prec(PrecisionPlugin):
     class TrainTypePlugin(DDPPlugin):
         pass

-    ttp = TrainTypePlugin(
-        device=torch.device("cpu"),
-        accelerator=Accel(),
-        precision_plugin=Prec()
-    )
+    ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
     trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
     assert isinstance(trainer.accelerator, Accel)
     assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
@@ -1038,10 +1034,13 @@ def test_unsupported_tpu_choice(monkeypatch):
     with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
         Trainer(accelerator="tpu", precision=64)

-    with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
-        Trainer(accelerator="tpu", precision=16)
-    with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
-        Trainer(accelerator="tpu", precision=16, amp_backend="apex")
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+        with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
+            Trainer(accelerator="tpu", precision=16)
+
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
+        with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
+            Trainer(accelerator="tpu", precision=16, amp_backend="apex")


 def test_unsupported_ipu_choice(monkeypatch):
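
On the updated test (an illustrative sketch relying on standard pytest context managers, not on anything specific to this repo): the new expectation is that the same Trainer call both emits the AMP warning and raises the SingleTPUPlugin ValueError, so pytest.warns is nested inside pytest.raises. A self-contained version of the pattern, with a hypothetical configure_tpu() helper standing in for the Trainer constructor:

    import warnings

    import pytest

    def configure_tpu():
        # Hypothetical helper standing in for Trainer(accelerator="tpu", precision=16):
        # it emits the AMP warning first, then fails the plugin check.
        warnings.warn("native AMP is not supported", UserWarning)
        raise ValueError("`TPUAccelerator` can only be used with a `SingleTPUPlugin`.")

    def test_warns_then_raises():
        # The outer context asserts the exception; the inner context captures the
        # warning emitted by the same call before it failed.
        with pytest.raises(ValueError, match="SingleTPUPlugin"):
            with pytest.warns(UserWarning, match="native AMP is not supported"):
                configure_tpu()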

0 commit comments
