Skip to content

Commit 2ea22df

Browse files
committed
Add tests
1 parent 86abd61 commit 2ea22df

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

tests/accelerators/test_accelerator_connector.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
SLURMEnvironment,
4242
TorchElasticEnvironment,
4343
)
44-
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4544
from pytorch_lightning.utilities import _GPU_AVAILABLE
45+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4646
from tests.helpers.boring_model import BoringModel
4747
from tests.helpers.runif import RunIf
4848

@@ -562,9 +562,9 @@ def test_accelerator_cpu():
562562

563563
with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but GPUs are not available"):
564564
trainer = Trainer(accelerator="gpu")
565-
566-
# with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but GPUs are not available"):
567-
# trainer = Trainer(accelerator="cpu", gpus=1)
565+
566+
with pytest.raises(MisconfigurationException, match="You requested GPUs:"):
567+
trainer = Trainer(accelerator="cpu", gpus=1)
568568

569569

570570
@RunIf(min_gpus=1)
@@ -575,7 +575,9 @@ def test_accelerator_gpu():
575575
assert trainer._device_type == "gpu"
576576
assert isinstance(trainer.accelerator, GPUAccelerator)
577577

578-
with pytest.raises(MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`"):
578+
with pytest.raises(
579+
MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`"
580+
):
579581
trainer = Trainer(accelerator="gpu")
580582

581583
trainer = Trainer(accelerator="auto", gpus=1)
@@ -584,9 +586,10 @@ def test_accelerator_gpu():
584586
assert isinstance(trainer.accelerator, GPUAccelerator)
585587

586588

589+
@RunIf(min_gpus=1)
587590
def test_accelerator_cpu_with_gpus_flag():
588-
591+
589592
trainer = Trainer(accelerator="cpu", gpus=1)
590593

591594
assert trainer._device_type == "cpu"
592-
assert isinstance(trainer.accelerator, GPUAccelerator)
595+
assert isinstance(trainer.accelerator, CPUAccelerator)

tests/accelerators/test_tpu_backend.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from torch import nn
1717

1818
from pytorch_lightning import Trainer
19+
from pytorch_lightning.accelerators.cpu import CPUAccelerator
20+
from pytorch_lightning.accelerators.tpu import TPUAccelerator
21+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1922
from tests.helpers.boring_model import BoringModel
2023
from tests.helpers.runif import RunIf
2124
from tests.helpers.utils import pl_multi_process_test
@@ -113,3 +116,31 @@ def on_post_move_to_device(self):
113116

114117
with pytest.warns(UserWarning, match="The model layers do not match"):
115118
trainer.fit(model)
119+
120+
121+
@RunIf(tpu=True)
122+
def test_accelerator_tpu():
123+
124+
trainer = Trainer(accelerator="tpu", tpu_cores=8)
125+
126+
assert trainer._device_type == "tpu"
127+
assert isinstance(trainer.accelerator, TPUAccelerator)
128+
129+
with pytest.raises(
130+
MisconfigurationException, match="You passed `accelerator='tpu'`, but you didn't pass `tpu_cores` to `Trainer`"
131+
):
132+
trainer = Trainer(accelerator="tpu")
133+
134+
trainer = Trainer(accelerator="auto", tpu_cores=8)
135+
136+
assert trainer._device_type == "tpu"
137+
assert isinstance(trainer.accelerator, TPUAccelerator)
138+
139+
140+
@RunIf(tpu=True)
141+
def test_accelerator_cpu_with_tpu_cores_flag():
142+
143+
trainer = Trainer(accelerator="cpu", tpu_cores=8)
144+
145+
assert trainer._device_type == "cpu"
146+
assert isinstance(trainer.accelerator, CPUAccelerator)

0 commit comments

Comments
 (0)