Skip to content

Commit ead59b3

Browse files
committed
address comments for testing
1 parent 4fba371 commit ead59b3

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

tests/accelerators/test_tpu.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,20 @@ def test_mp_device_dataloader_attribute(_):
310310
def test_warning_if_tpus_not_used():
    """Constructing a Trainer on a TPU host without requesting TPUs should warn."""
    expected = "TPU available but not used. Set `accelerator` and `devices`"
    with pytest.warns(UserWarning, match=expected):
        Trainer()
313+
314+
315+
@RunIf(tpu=True)
@pytest.mark.parametrize(
    ["trainer_kwargs", "expected_device_ids"],
    [
        ({"accelerator": "tpu", "devices": 1}, [0]),
        ({"accelerator": "tpu", "devices": 8}, list(range(8))),
        ({"accelerator": "tpu", "devices": "8"}, list(range(8))),
        ({"accelerator": "tpu", "devices": [2]}, [2]),
        ({"accelerator": "tpu", "devices": "2,"}, [2]),
    ],
)
def test_trainer_config_device_ids(trainer_kwargs, expected_device_ids):
    """Trainer should parse every `devices` spelling (int, str, list, "n,") into device ids.

    FIX: dropped the unused `monkeypatch` fixture from the signature — nothing is
    patched here, since `@RunIf(tpu=True)` already guarantees real TPU hardware.
    """
    trainer = Trainer(**trainer_kwargs)
    assert trainer.device_ids == expected_device_ids
    # num_devices must agree with the parsed id list, not the raw flag value.
    assert trainer.num_devices == len(expected_device_ids)

tests/deprecated_api/test_remove_1-8.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
from torch import optim
2222

23+
import pytorch_lightning
2324
from pytorch_lightning import Callback, Trainer
2425
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
2526
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
@@ -695,19 +696,28 @@ def on_load_checkpoint(self, checkpoint):
695696
[
696697
({}, 1),
697698
({"devices": 1}, 1),
699+
({"devices": 1}, 1),
700+
({"devices": "1"}, 1),
701+
({"devices": 2}, 2),
698702
({"accelerator": "gpu", "devices": 1}, 1),
699-
({"strategy": "ddp", "devices": 1}, 1),
700-
({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, 1),
701-
({"strategy": "ddp", "devices": 2}, 2),
702-
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, 2),
703-
({"strategy": "ddp", "accelerator": "gpu", "devices": [2]}, 1),
704-
({"strategy": "ddp", "accelerator": "gpu", "devices": [0, 2]}, 2),
703+
({"accelerator": "gpu", "devices": 2}, 2),
704+
({"accelerator": "gpu", "devices": "2"}, 2),
705+
({"accelerator": "gpu", "devices": [2]}, 1),
706+
({"accelerator": "gpu", "devices": "2,"}, 1),
707+
({"accelerator": "gpu", "devices": [0, 2]}, 2),
708+
({"accelerator": "gpu", "devices": "0, 2"}, 2),
709+
({"accelerator": "ipu", "devices": 1}, 1),
710+
({"accelerator": "ipu", "devices": 2}, 2),
705711
],
706712
)
707-
def test_v1_8_0_trainer_devices(monkeypatch, trainer_kwargs, expected_devices):
713+
def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_devices):
708714
if trainer_kwargs.get("accelerator") == "gpu":
709715
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
710716
monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
717+
elif trainer_kwargs.get("accelerator") == "ipu":
718+
monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
719+
monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", lambda: True)
720+
711721
trainer = Trainer(**trainer_kwargs)
712722
with pytest.deprecated_call(
713723
match="`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."

tests/trainer/test_trainer.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from torch.optim import SGD
3131
from torch.utils.data import DataLoader, IterableDataset
3232

33+
import pytorch_lightning
3334
import tests.helpers.utils as tutils
3435
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
3536
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator
@@ -2153,19 +2154,28 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
21532154
[
21542155
({}, [0]),
21552156
({"devices": 1}, [0]),
2157+
({"devices": 1}, [0]),
2158+
({"devices": "1"}, [0]),
2159+
({"devices": 2}, [0, 1]),
21562160
({"accelerator": "gpu", "devices": 1}, [0]),
2157-
({"strategy": "ddp", "devices": 1}, [0]),
2158-
({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, [0]),
2159-
({"strategy": "ddp", "devices": 2}, [0, 1]),
2160-
({"strategy": "ddp", "accelerator": "gpu", "devices": 2}, [0, 1]),
2161-
({"strategy": "ddp", "accelerator": "gpu", "devices": [2]}, [2]),
2162-
({"strategy": "ddp", "accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
2161+
({"accelerator": "gpu", "devices": 2}, [0, 1]),
2162+
({"accelerator": "gpu", "devices": "2"}, [0, 1]),
2163+
({"accelerator": "gpu", "devices": [2]}, [2]),
2164+
({"accelerator": "gpu", "devices": "2,"}, [2]),
2165+
({"accelerator": "gpu", "devices": [0, 2]}, [0, 2]),
2166+
({"accelerator": "gpu", "devices": "0, 2"}, [0, 2]),
2167+
({"accelerator": "ipu", "devices": 1}, [0]),
2168+
({"accelerator": "ipu", "devices": 2}, [0, 1]),
21632169
],
21642170
)
21652171
def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
    """Trainer should expose parsed `device_ids` / `num_devices` for each accelerator.

    GPU and IPU availability is faked via monkeypatch so the parametrized cases
    run on any host.
    """
    if trainer_kwargs.get("accelerator") == "gpu":
        # Pretend a 4-GPU machine so index selections like [0, 2] resolve.
        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
        monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
    elif trainer_kwargs.get("accelerator") == "ipu":
        monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True)
        # FIX(review): `_IPU_AVAILABLE` looks like a module-level bool flag, not a
        # callable; patching it to `lambda: True` only passed because a lambda is
        # truthy. Patch the boolean itself — confirm against the strategies.ipu module.
        monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", True)

    trainer = Trainer(**trainer_kwargs)
    assert trainer.device_ids == expected_device_ids
    assert trainer.num_devices == len(expected_device_ids)

0 commit comments

Comments (0)