
Commit a442852 ("address comments")
1 parent cd12345

6 files changed, 21 insertions(+), 15 deletions(-)

pytorch_lightning/accelerators/gpu.py

Lines changed: 2 additions & 1 deletion
@@ -82,7 +82,8 @@ def auto_device_count() -> int:

     @staticmethod
     def is_available() -> bool:
-        return torch.cuda.device_count() > 0
+        print(torch.cuda.is_available() and torch.cuda.device_count() > 0)
+        return torch.cuda.is_available() and torch.cuda.device_count() > 0


 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
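
For context, a minimal standalone sketch of the strengthened check (the gpu_is_usable helper name is illustrative, not part of the library); it requires both a working CUDA runtime and at least one visible device instead of relying on the device count alone:

import torch

def gpu_is_usable() -> bool:
    # Mirrors the updated GPUAccelerator.is_available() logic:
    # require a functional CUDA runtime *and* at least one visible device.
    return torch.cuda.is_available() and torch.cuda.device_count() > 0

if __name__ == "__main__":
    print("GPU usable:", gpu_is_usable())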

pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@
 class TPUSpawnStrategy(DDPSpawnStrategy):
     """Strategy for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method."""

-    strategy_name = "tpu_spawn_strategy"
+    strategy_name = "tpu_spawn"

     def __init__(
         self,
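
After the rename, the strategy's registered identifier is the short form "tpu_spawn". A small sketch of how the name can be checked, plus a hypothetical Trainer call; passing the string through the strategy flag and the TPU hardware are assumptions, not something this commit demonstrates:

from pytorch_lightning.strategies.tpu_spawn import TPUSpawnStrategy

# The class-level identifier now matches the short form.
assert TPUSpawnStrategy.strategy_name == "tpu_spawn"

# Hypothetical usage, assuming the Trainer accepts the strategy by this name
# (actual TPU hardware is required for the strategy to initialise):
# trainer = Trainer(accelerator="tpu", devices=8, strategy="tpu_spawn")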

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 10 additions & 2 deletions
@@ -451,6 +451,8 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
         elif self._accelerator_flag == "gpu":
             self.accelerator = GPUAccelerator()
             self._set_devices_flag_if_auto_passed()
+            # TODO add device availablity check for all devices, not only GPU
+            self._check_device_availability()
             if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str):
                 self._devices_flag = int(self._devices_flag)
                 self._parallel_devices = (
@@ -459,7 +461,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
                     else []
                 )
             else:
-                self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag]
+                self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag]  # type: ignore

         elif self._accelerator_flag == "cpu":
             self.accelerator = CPUAccelerator()
@@ -479,6 +481,12 @@ def _set_devices_flag_if_auto_passed(self) -> None:
         if self._devices_flag == "auto" or not self._devices_flag:
             self._devices_flag = self.accelerator.auto_device_count()

+    def _check_device_availability(self) -> None:
+        if not self.accelerator.is_available():
+            raise MisconfigurationException(
+                f"You requested {self._accelerator_flag}, " f"but {self._accelerator_flag} is not available"
+            )
+
     def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
         if isinstance(self._cluster_environment_flag, ClusterEnvironment):
             return self._cluster_environment_flag
@@ -514,7 +522,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             return DDPStrategy.strategy_name
         if len(self._parallel_devices) <= 1:
             device = (
-                device_parser.determine_root_gpu_device(self._parallel_devices)
+                device_parser.determine_root_gpu_device(self._parallel_devices)  # type: ignore
                 if self._accelerator_flag == "gpu"
                 else "cpu"
             )
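
From the user's side, the new _check_device_availability call makes a GPU request fail fast with a clear message when no usable GPU is present. A minimal sketch of that behaviour, assuming a CPU-only machine:

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Without a usable GPU, the availability check raises early
# instead of failing later during device setup.
try:
    trainer = Trainer(accelerator="gpu")
except MisconfigurationException as err:
    print(err)  # "You requested gpu, but gpu is not available"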

pytorch_lightning/utilities/device_parser.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _DEVICE


 def determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]:
@@ -164,7 +165,7 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
-                f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}"
+                f"You requested gpu: {gpus}\n But your machine only has: {all_available_gpus}"
             )
     return gpus
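
The reworded message belongs to the GPU-id validation in _sanitize_gpu_ids. A simplified standalone mirror of that validation, using ValueError instead of MisconfigurationException and a local helper name, looks roughly like this:

from typing import List

import torch

def sanitize_gpu_ids(requested: List[int]) -> List[int]:
    # Every requested index must be one of the machine's visible GPUs,
    # otherwise the request is rejected with the available ids listed.
    available = list(range(torch.cuda.device_count()))
    for gpu in requested:
        if gpu not in available:
            raise ValueError(
                f"You requested gpu: {requested}\n But your machine only has: {available}"
            )
    return requested

# On a single-GPU machine (only index 0 visible), this raises:
# sanitize_gpu_ids([0, 1])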

tests/accelerators/test_accelerator_connector.py

Lines changed: 4 additions & 8 deletions
@@ -453,12 +453,11 @@ def test_accelerator_cpu(mack_gpu_avalible):
     assert trainer._device_type == "cpu"
     assert isinstance(trainer.accelerator, CPUAccelerator)

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match="You requested gpu"):
         trainer = Trainer(gpus=1)
-    # with pytest.raises(MisconfigurationException):
-    #     trainer = Trainer(accelerator="gpu")
-
-    with pytest.raises(MisconfigurationException, match="You requested GPUs:"):
+    with pytest.raises(MisconfigurationException, match="You requested gpu, but gpu is not available"):
+        trainer = Trainer(accelerator="gpu")
+    with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(accelerator="cpu", gpus=1)


@@ -470,9 +469,6 @@ def test_accelerator_gpu():
     assert trainer._device_type == "gpu"
     assert isinstance(trainer.accelerator, GPUAccelerator)

-    # with pytest.raises(
-    #     MisconfigurationException, match="You passed `accelerator='gpu'`, but you didn't pass `gpus` to `Trainer`"
-    # ):
     trainer = Trainer(accelerator="gpu")

     trainer = Trainer(accelerator="auto", gpus=1)
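
The updated assertions rely on pytest.raises matching the new error text; match is applied with re.search against the exception message, so a plain substring is enough as long as it contains no unescaped regex metacharacters. A self-contained illustration:

import pytest

from pytorch_lightning.utilities.exceptions import MisconfigurationException

def test_error_message_matches():
    # match performs a regex search, so a substring of the message suffices.
    with pytest.raises(MisconfigurationException, match="is not available"):
        raise MisconfigurationException("You requested gpu, but gpu is not available")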

tests/utilities/test_cli.py

Lines changed: 2 additions & 2 deletions
@@ -582,8 +582,8 @@ def on_fit_start(self):
     (
         # dict(strategy="ddp_spawn")
         # dict(strategy="ddp")
-        # !! old accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
-        # this test never worked with DDPSpawnStrategy
+        # the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
+        # TODO revisit this test as it never worked with DDP or DDPSpawn
         dict(strategy="single_device"),
         pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
     ),
