Skip to content

Commit 2e35731

Browse files
committed
fix tpu error check
1 parent 9dc0962 commit 2e35731

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

tests/accelerators/test_accelerator_connector.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,7 @@ class Prec(PrecisionPlugin):
418418
class TrainTypePlugin(DDPPlugin):
419419
pass
420420

421-
ttp = TrainTypePlugin(
422-
device=torch.device("cpu"),
423-
accelerator=Accel(),
424-
precision_plugin=Prec()
425-
)
421+
ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
426422
trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
427423
assert isinstance(trainer.accelerator, Accel)
428424
assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
@@ -1038,10 +1034,13 @@ def test_unsupported_tpu_choice(monkeypatch):
10381034
with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
10391035
Trainer(accelerator="tpu", precision=64)
10401036

1041-
with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
1042-
Trainer(accelerator="tpu", precision=16)
1043-
with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
1044-
Trainer(accelerator="tpu", precision=16, amp_backend="apex")
1037+
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
1038+
with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
1039+
Trainer(accelerator="tpu", precision=16)
1040+
1041+
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin`"):
1042+
with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"):
1043+
Trainer(accelerator="tpu", precision=16, amp_backend="apex")
10451044

10461045

10471046
def test_unsupported_ipu_choice(monkeypatch):

0 commit comments

Comments
 (0)