@@ -895,8 +895,8 @@ def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy):
895
895
monkeypatch .setattr (torch .cuda , "is_available" , lambda : True )
896
896
monkeypatch .setattr (torch .cuda , "device_count" , lambda : 16 )
897
897
with pytest .deprecated_call (
898
- match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. Please use "
899
- r" `Trainer.strategy.root_device.index if isinstance\(Trainer.accelerator, GPUAccelerator\) else None ` instead."
898
+ match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
899
+ "Please use `Trainer.strategy.root_device.index` instead."
900
900
):
901
901
assert Trainer (gpus = gpus , strategy = strategy ).root_gpu == expected_root_gpu
902
902
@@ -912,7 +912,7 @@ def test_root_gpu_property(monkeypatch, gpus, expected_root_gpu, strategy):
912
912
def test_root_gpu_property_0_passing (monkeypatch , gpus , expected_root_gpu , strategy ):
913
913
monkeypatch .setattr (torch .cuda , "device_count" , lambda : 0 )
914
914
with pytest .deprecated_call (
915
- match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. Please use "
916
- r" `Trainer.strategy.root_device.index if isinstance\(Trainer.accelerator, GPUAccelerator\) else None ` instead."
915
+ match = "`Trainer.root_gpu` is deprecated in v1.6 and will be removed in v1.8. "
916
+ "Please use `Trainer.strategy.root_device.index` instead."
917
917
):
918
918
assert Trainer (gpus = gpus , strategy = strategy ).root_gpu == expected_root_gpu
0 commit comments