@@ -904,8 +904,8 @@ def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy):
904
904
monkeypatch .setattr (torch .cuda , "is_available" , lambda : True )
905
905
monkeypatch .setattr (torch .cuda , "device_count" , lambda : 16 )
906
906
with pytest .deprecated_call (
907
- match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. Please use "
908
- r" `Trainer.strategy.root_device.index if isinstance\(Trainer.accelerator, GPUAccelerator\) else None ` instead."
907
+ match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
908
+ "Please use `Trainer.strategy.root_device.index` instead."
909
909
):
910
910
assert Trainer (gpus = gpus , strategy = strategy ).root_gpu == expected_root_gpu
911
911
@@ -921,7 +921,7 @@ def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy):
921
921
def test_root_gpu_property_0_passing (monkeypatch , gpus , expected_root_gpu , strategy ):
922
922
monkeypatch .setattr (torch .cuda , "device_count" , lambda : 0 )
923
923
with pytest .deprecated_call (
924
- match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. Please use "
925
- r" `Trainer.strategy.root_device.index if isinstance\(Trainer.accelerator, GPUAccelerator\) else None ` instead."
924
+ match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
925
+ "Please use `Trainer.strategy.root_device.index` instead."
926
926
):
927
927
assert Trainer (gpus = gpus , strategy = strategy ).root_gpu == expected_root_gpu
0 commit comments