@@ -418,11 +418,7 @@ class Prec(PrecisionPlugin):
418
418
class TrainTypePlugin (DDPPlugin ):
419
419
pass
420
420
421
- ttp = TrainTypePlugin (
422
- device = torch .device ("cpu" ),
423
- accelerator = Accel (),
424
- precision_plugin = Prec ()
425
- )
421
+ ttp = TrainTypePlugin (device = torch .device ("cpu" ), accelerator = Accel (), precision_plugin = Prec ())
426
422
trainer = Trainer (strategy = ttp , fast_dev_run = True , num_processes = 2 )
427
423
assert isinstance (trainer .accelerator , Accel )
428
424
assert isinstance (trainer .training_type_plugin , TrainTypePlugin )
@@ -1038,10 +1034,13 @@ def test_unsupported_tpu_choice(monkeypatch):
1038
1034
with pytest .raises (MisconfigurationException , match = r"accelerator='tpu', precision=64\)` is not implemented" ):
1039
1035
Trainer (accelerator = "tpu" , precision = 64 )
1040
1036
1041
- with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but native AMP is not supported" ):
1042
- Trainer (accelerator = "tpu" , precision = 16 )
1043
- with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but apex AMP is not supported" ):
1044
- Trainer (accelerator = "tpu" , precision = 16 , amp_backend = "apex" )
1037
+ with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `SingleTPUPlugin`" ):
1038
+ with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but native AMP is not supported" ):
1039
+ Trainer (accelerator = "tpu" , precision = 16 )
1040
+
1041
+ with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `SingleTPUPlugin`" ):
1042
+ with pytest .warns (UserWarning , match = r"accelerator='tpu', precision=16\)` but apex AMP is not supported" ):
1043
+ Trainer (accelerator = "tpu" , precision = 16 , amp_backend = "apex" )
1045
1044
1046
1045
1047
1046
def test_unsupported_ipu_choice (monkeypatch ):
0 commit comments