Skip to content

Commit f974bc6

Browse files
committed
improve tests and docstring
1 parent 6b27c1c commit f974bc6

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,7 +2024,7 @@ def device_ids(self) -> List[int]:
20242024

20252025
@property
20262026
def num_devices(self) -> int:
2027-
"""Number of devices per node."""
2027+
"""Number of devices the trainer uses per node."""
20282028
return len(self.strategy.parallel_devices) if isinstance(self.strategy, ParallelStrategy) else 1
20292029

20302030
@property

tests/accelerators/test_tpu.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,16 +314,16 @@ def test_warning_if_tpus_not_used():
314314

315315
@RunIf(tpu=True)
316316
@pytest.mark.parametrize(
317-
["trainer_kwargs", "expected_device_ids"],
317+
["devices", "expected_device_ids"],
318318
[
319-
({"accelerator": "tpu", "devices": 1}, [0]),
320-
({"accelerator": "tpu", "devices": 8}, list(range(8))),
321-
({"accelerator": "tpu", "devices": "8"}, list(range(8))),
322-
({"accelerator": "tpu", "devices": [2]}, [2]),
323-
({"accelerator": "tpu", "devices": "2,"}, [2]),
319+
(1, [0]),
320+
(8, list(range(8))),
321+
("8", list(range(8))),
322+
([2], [2]),
323+
("2,", [2]),
324324
],
325325
)
326-
def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ids):
327-
trainer = Trainer(**trainer_kwargs)
326+
def test_trainer_config_device_ids(devices, expected_device_ids):
327+
trainer = Trainer(accelerator="tpu", devices=devices)
328328
assert trainer.device_ids == expected_device_ids
329329
assert trainer.num_devices == len(expected_device_ids)

0 commit comments

Comments
 (0)