@@ -550,22 +550,24 @@ def test_accelerator_cpu_with_num_processes_priority():
550
550
551
551
@RunIf (min_gpus = 2 )
552
552
@pytest .mark .parametrize (
553
- ["devices" , "plugin" ], [(1 , SingleDeviceStrategy ), ([1 ], SingleDeviceStrategy ), (2 , DDPSpawnStrategy )]
553
+ ["devices" , "plugin" , "num_devices" , "device_ids" ],
554
+ [(1 , SingleDeviceStrategy , 1 , [0 ]), ([1 ], SingleDeviceStrategy , 1 , [1 ]), (2 , DDPSpawnStrategy , 2 , [0 , 1 ])],
554
555
)
555
- def test_accelerator_gpu_with_devices (devices , plugin ):
556
+ def test_accelerator_gpu_with_devices (devices , plugin , num_devices , device_ids ):
556
557
557
558
trainer = Trainer (accelerator = "gpu" , devices = devices )
558
559
559
- assert trainer .gpus == devices
560
560
assert isinstance (trainer .strategy , plugin )
561
561
assert isinstance (trainer .accelerator , GPUAccelerator )
562
+ assert trainer .num_devices == num_devices
563
+ assert trainer .device_ids == device_ids
562
564
563
565
564
566
@RunIf (min_gpus = 1 )
565
567
def test_accelerator_auto_with_devices_gpu ():
566
568
trainer = Trainer (accelerator = "auto" , devices = 1 )
567
569
assert isinstance (trainer .accelerator , GPUAccelerator )
568
- assert trainer .gpus == 1
570
+ assert trainer .num_devices == 1
569
571
570
572
571
573
@RunIf (min_gpus = 1 )
@@ -576,7 +578,7 @@ def test_accelerator_gpu_with_gpus_priority():
576
578
with pytest .warns (UserWarning , match = "The flag `devices=4` will be ignored," ):
577
579
trainer = Trainer (accelerator = "gpu" , devices = 4 , gpus = gpus )
578
580
579
- assert trainer .gpus == gpus
581
+ assert trainer .num_devices == gpus
580
582
581
583
582
584
def test_validate_accelerator_and_devices ():
@@ -968,8 +970,8 @@ def test_devices_auto_choice_cpu(is_ipu_available_mock, is_tpu_available_mock, i
968
970
@mock .patch ("torch.cuda.device_count" , return_value = 2 )
969
971
def test_devices_auto_choice_gpu (is_gpu_available_mock , device_count_mock ):
970
972
trainer = Trainer (accelerator = "auto" , devices = "auto" )
973
+ assert isinstance (trainer .accelerator , GPUAccelerator )
971
974
assert trainer .num_devices == 2
972
- assert trainer .gpus == 2
973
975
974
976
975
977
@pytest .mark .parametrize (
0 commit comments