@@ -242,20 +242,11 @@ def test_dataloaders_passed_to_fit(tmpdir):
242
242
assert trainer .state .finished , f"Training failed with { trainer .state } "
243
243
244
244
245
- @pytest .mark .parametrize (
246
- ["tpu_cores" , "expected_tpu_id" ],
247
- [(1 , None ), (8 , None ), ([1 ], 1 ), ([8 ], 8 )],
248
- )
249
245
@RunIf (tpu = True )
250
- def test_tpu_id_to_be_as_expected (tpu_cores , expected_tpu_id ):
251
- """Test if trainer.tpu_id is set as expected."""
252
- assert Trainer (tpu_cores = tpu_cores )._accelerator_connector .tpu_id == expected_tpu_id
253
-
254
-
255
- def test_tpu_misconfiguration ():
256
- """Test if trainer.tpu_id is set as expected."""
246
+ @pytest .mark .parametrize ("tpu_cores" , [[1 , 8 ], "9, " , [9 ], [0 ], 2 , 10 ])
247
+ def test_tpu_misconfiguration (tpu_cores ):
257
248
with pytest .raises (MisconfigurationException , match = "`tpu_cores` can only be" ):
258
- Trainer (tpu_cores = [ 1 , 8 ] )
249
+ Trainer (tpu_cores = tpu_cores )
259
250
260
251
261
252
@pytest .mark .skipif (_TPU_AVAILABLE , reason = "test requires missing TPU" )
@@ -289,33 +280,6 @@ def test_broadcast(rank):
289
280
xmp .spawn (test_broadcast , nprocs = 8 , start_method = "fork" )
290
281
291
282
292
- @pytest .mark .parametrize (
293
- ["tpu_cores" , "expected_tpu_id" , "error_expected" ],
294
- [
295
- (1 , None , False ),
296
- (8 , None , False ),
297
- ([1 ], 1 , False ),
298
- ([8 ], 8 , False ),
299
- ("1," , 1 , False ),
300
- ("1" , None , False ),
301
- ("9, " , 9 , True ),
302
- ([9 ], 9 , True ),
303
- ([0 ], 0 , True ),
304
- (2 , None , True ),
305
- (10 , None , True ),
306
- ],
307
- )
308
- @RunIf (tpu = True )
309
- @pl_multi_process_test
310
- def test_tpu_choice (tmpdir , tpu_cores , expected_tpu_id , error_expected ):
311
- if error_expected :
312
- with pytest .raises (MisconfigurationException , match = r".*tpu_cores` can only be 1, 8 or [<1-8>]*" ):
313
- Trainer (default_root_dir = tmpdir , tpu_cores = tpu_cores )
314
- else :
315
- trainer = Trainer (default_root_dir = tmpdir , tpu_cores = tpu_cores )
316
- assert trainer ._accelerator_connector .tpu_id == expected_tpu_id
317
-
318
-
319
283
@pytest .mark .parametrize (
320
284
["cli_args" , "expected" ],
321
285
[("--tpu_cores=8" , {"tpu_cores" : 8 }), ("--tpu_cores=1," , {"tpu_cores" : "1," })],
0 commit comments