Attempt to query device count via NVML #14631

Merged · 19 commits · Sep 22, 2022
81 changes: 70 additions & 11 deletions src/lightning_lite/accelerators/cuda.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from typing import Dict, List, Optional, Union
import os
import warnings
from functools import lru_cache
from typing import Dict, List, Optional, Set, Union

import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13


class CUDAAccelerator(Accelerator):
@@ -75,16 +77,20 @@ def _get_all_available_cuda_gpus() -> List[int]:
return list(range(num_cuda_devices()))


@lru_cache(1)
def num_cuda_devices() -> int:
"""Returns the number of GPUs available.
"""Returns the number of available CUDA devices.

Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
if _TORCH_GREATER_EQUAL_1_13:
return torch.cuda.device_count()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.device_count)

# Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
# TODO: Remove once minimum supported PyTorch version is 1.13
nvml_count = _device_count_nvml()
return torch.cuda.device_count() if nvml_count < 0 else nvml_count


def is_cuda_available() -> bool:
@@ -93,7 +99,60 @@ def is_cuda_available() -> bool:
Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
return torch.cuda.is_available()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.is_available)
return num_cuda_devices() > 0


def _parse_visible_devices() -> Set[int]:
    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
    var = os.getenv("CUDA_VISIBLE_DEVICES")
    if var is None:
        return {x for x in range(64)}

    def _strtoul(s: str) -> int:
        """Return -1 or the integer the string starts with."""
        if len(s) == 0:
            return -1
        for idx, c in enumerate(s):
            if not c.isdigit():
                break
            if idx + 1 == len(s):
                idx += 1
        return int(s[:idx]) if idx > 0 else -1

    # CUDA_VISIBLE_DEVICES uses something like strtoul,
    # which makes `1gpu2,2ampere` equivalent to `1,2`
    rc: Set[int] = set()
    for elem in var.split(","):
        rc.add(_strtoul(elem.strip()))
    return rc
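
As a quick illustration of the strtoul-like parsing described in the comment above (examples only, assuming the module path shown in this diff):

```python
import os

from lightning_lite.accelerators.cuda import _parse_visible_devices

os.environ.pop("CUDA_VISIBLE_DEVICES", None)
assert _parse_visible_devices() == set(range(64))  # unset: all 64 possible device indices

os.environ["CUDA_VISIBLE_DEVICES"] = "0, 2"
assert _parse_visible_devices() == {0, 2}

os.environ["CUDA_VISIBLE_DEVICES"] = "1gpu2,2ampere"
assert _parse_visible_devices() == {1, 2}  # non-digit suffixes after the leading integer are dropped
```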


def _raw_device_count_nvml() -> int:
    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
    from ctypes import c_int, CDLL

    nvml_h = CDLL("libnvidia-ml.so.1")
    rc = nvml_h.nvmlInit()
    if rc != 0:
        warnings.warn("Can't initialize NVML")
        return -1
    dev_arr = (c_int * 1)(-1)
    rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
    if rc != 0:
        warnings.warn("Can't get nvml device count")
        return -1
    del nvml_h
    return dev_arr[0]
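
For comparison only, here is roughly the same query through the third-party `pynvml` bindings. This is not what the PR uses (the ctypes call above avoids adding a dependency); it is just a sketch of the equivalent NVML usage:

```python
def _raw_device_count_pynvml() -> int:
    """Illustrative pynvml-based equivalent of the ctypes helper above."""
    try:
        import pynvml  # shipped as the `nvidia-ml-py` / `pynvml` package

        pynvml.nvmlInit()
        try:
            return pynvml.nvmlDeviceGetCount()
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        # mirror the ctypes version: report "unknown" rather than raising
        return -1
```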


def _device_count_nvml() -> int:
    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
    try:
        raw_cnt = _raw_device_count_nvml()
        if raw_cnt <= 0:
            return raw_cnt
        return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
    except OSError:
        return -1
    except AttributeError:
        return -1
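
The reason for routing through NVML: the device count can be read in the parent process without ever creating a CUDA context, so worker processes started with the "fork" method can still initialize CUDA themselves. A minimal sketch of that property (assuming a Linux machine with NVIDIA drivers and the module path from this diff):

```python
import torch
import torch.multiprocessing as mp

from lightning_lite.accelerators.cuda import num_cuda_devices


def _worker(rank: int) -> None:
    # Would raise "Cannot re-initialize CUDA in forked subprocess" if the parent
    # process had already created a CUDA context.
    torch.zeros(1, device=f"cuda:{rank}")


if __name__ == "__main__":
    print("devices:", num_cuda_devices())  # NVML-backed, no CUDA context in the parent
    if num_cuda_devices() > 0:
        mp.start_processes(_worker, nprocs=1, start_method="fork")
```
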
22 changes: 22 additions & 0 deletions src/lightning_lite/utilities/device_parser.py
@@ -187,3 +187,25 @@ def _check_data_type(device_ids: Any) -> None:
raise MisconfigurationException(f"{msg} a sequence of {type(id_).__name__}.")
elif type(device_ids) not in (int, str):
raise MisconfigurationException(f"{msg} {type(device_ids).__name__}.")


def _tpu_cores_valid(tpu_cores: Any) -> bool:
    # allow 1 or 8 cores
    if tpu_cores in (1, 8, None):
        return True

    # allow picking 1 of 8 indexes
    if isinstance(tpu_cores, (list, tuple, set)):
        has_1_tpu_idx = len(tpu_cores) == 1
        is_valid_tpu_idx = 1 <= list(tpu_cores)[0] <= 8

        is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx
        return is_valid_tpu_core_choice

    return False


def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
    if tpu_cores in ("1", "8"):
        return int(tpu_cores)
    return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]
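
A few illustrative calls for the two helpers added above (examples only):

```python
from lightning_lite.utilities.device_parser import _parse_tpu_cores_str, _tpu_cores_valid

assert _parse_tpu_cores_str("8") == 8      # "1" or "8" request that many cores
assert _parse_tpu_cores_str("5,") == [5]   # anything else is a list of core indices
assert _tpu_cores_valid(8)                 # 1, 8 or None are accepted counts
assert _tpu_cores_valid([5])               # or exactly one index in the range 1..8
assert not _tpu_cores_valid([2, 3])        # multiple indices are rejected
```
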
@@ -68,10 +68,6 @@ def __init__(
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
f" {', '.join(mp.get_all_start_methods())}"
)
if start_method in ("fork", "forkserver") and _is_forking_disabled():
raise ValueError(
"Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method."
)

@property
def is_interactive_compatible(self) -> bool:
@@ -287,8 +283,3 @@ def restore(self) -> None:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)


def _is_forking_disabled() -> bool:
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
@@ -75,7 +75,6 @@
TPUSpawnStrategy,
)
from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
@@ -632,10 +631,6 @@ def _check_strategy_and_fallback(self) -> None:
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
f" platform. We recommed `Trainer(strategy='ddp_spawn')` instead."
)
if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled():
raise ValueError(
"Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different strategy."
)
if strategy_flag:
self._strategy_flag = strategy_flag

9 changes: 8 additions & 1 deletion tests/tests_lite/accelerators/test_cpu.py
@@ -15,7 +15,7 @@
import pytest
import torch

from lightning_lite.accelerators.cpu import CPUAccelerator
from lightning_lite.accelerators.cpu import CPUAccelerator, parse_cpu_cores


def test_auto_device_count():
@@ -41,3 +41,10 @@ def test_init_device_with_wrong_device_type():
)
def test_get_parallel_devices(devices, expected):
assert CPUAccelerator.get_parallel_devices(devices) == expected


@pytest.mark.parametrize("devices", ([3], -1))
def test_invalid_devices_with_cpu_accelerator(devices):
"""Test invalid device flag raises MisconfigurationException."""
with pytest.raises(TypeError, match="should be an int > 0"):
parse_cpu_cores(devices)
11 changes: 10 additions & 1 deletion tests/tests_lite/accelerators/test_cuda.py
@@ -17,7 +17,7 @@
import torch
from tests_lite.helpers.runif import RunIf

from lightning_lite.accelerators.cuda import CUDAAccelerator
from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices


@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -51,3 +51,12 @@ def test_get_parallel_devices(devices, expected):
def test_set_cuda_device(set_device_mock):
CUDAAccelerator().setup_device(torch.device("cuda", 1))
set_device_mock.assert_called_once_with(torch.device("cuda", 1))


@mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1)
@mock.patch("torch.cuda.device_count", return_value=100)
def test_num_cuda_devices_without_forking(*_):
"""Test that if NVML can't be loaded, our helper functions fall back to the default implementation for
determining CUDA availability."""
assert is_cuda_available()
assert num_cuda_devices() == 100
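
A side note on this test (not part of the diff): `num_cuda_devices` is wrapped in `@lru_cache(1)`, so a count computed by an earlier call in the same process would be returned even with these mocks in place. Clearing the cache first avoids that kind of surprise:

```python
num_cuda_devices.cache_clear()  # functools.lru_cache exposes cache_clear()
```
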
22 changes: 0 additions & 22 deletions tests/tests_lite/utilities/test_device_parser.py
@@ -14,10 +14,7 @@
from unittest import mock

import pytest
import torch

from lightning_lite.accelerators.cpu import parse_cpu_cores
from lightning_lite.accelerators.cuda import is_cuda_available, num_cuda_devices
from lightning_lite.utilities import device_parser
from lightning_lite.utilities.exceptions import MisconfigurationException

@@ -87,22 +84,3 @@ def test_parse_gpu_fail_on_non_existent_id_2(_):
def test_parse_gpu_returns_none_when_no_devices_are_available(_, devices):
with pytest.raises(MisconfigurationException):
device_parser.parse_gpu_ids(devices, include_cuda=True)


@pytest.mark.skipif(
"fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support"
)
@mock.patch("torch.cuda.is_available", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=2)
def test_num_cuda_devices_without_forking(*_):
"""This merely tests that on platforms without fork support our helper functions fall back to the default
implementation for determining cuda availability."""
assert is_cuda_available()
assert num_cuda_devices() == 2


@pytest.mark.parametrize("devices", ([3], -1))
def test_invalid_devices_with_cpu_accelerator(devices):
"""Test invalid device flag raises MisconfigurationException."""
with pytest.raises(TypeError, match="should be an int > 0"):
parse_cpu_cores(devices)
@@ -11,15 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch

from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from tests_pytorch.helpers.runif import RunIf


@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -28,14 +26,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_multiprocessing_launcher_disabled_forking(start_method):
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
_MultiProcessingLauncher(strategy=Mock(), start_method=start_method)


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
@@ -794,12 +794,3 @@ def test_accelerator_specific_checkpoint_io(*_):
def test_ddp_fork_on_unsupported_platform(_, strategy):
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
Trainer(strategy=strategy)


@RunIf(skip_windows=True)
@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
"""Test there is an error when forking is disabled via the environment variable and the user requests fork."""
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
Trainer(devices=2, strategy=strategy)