
Attempt to query device count via NVML #14631


Merged · 19 commits · Sep 22, 2022
Changes from 6 commits
81 changes: 70 additions & 11 deletions src/lightning_lite/utilities/device_parser.py
@@ -1,11 +1,13 @@
import multiprocessing
from typing import Any, List, MutableSequence, Optional, Tuple, Union
import os
import warnings
from functools import lru_cache
from typing import Any, List, MutableSequence, Optional, Set, Tuple, Union

import torch

from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.types import _DEVICE


@@ -284,16 +286,20 @@ def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]:
return [int(x.strip()) for x in tpu_cores.split(",") if len(x) > 0]


@lru_cache(1)
def num_cuda_devices() -> int:
"""Returns the number of GPUs available.
"""Returns the number of available CUDA devices.

Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
if _TORCH_GREATER_EQUAL_1_13:
return torch.cuda.device_count()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.device_count)

# Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
# TODO: Remove once minimum supported PyTorch version is 1.13
nvml_count = _device_count_nvml()
return torch.cuda.device_count() if nvml_count < 0 else nvml_count


def is_cuda_available() -> bool:
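The NVML-based count matters because on PyTorch < 1.13 calling `torch.cuda.device_count()` initializes a CUDA context in the calling process, after which children created with the `fork` start method can no longer re-initialize CUDA. A rough sketch of the pattern this keeps working, assuming a hypothetical `_train` worker (illustrative only, not part of this diff):

    import torch
    import torch.multiprocessing as mp

    from lightning_lite.utilities.device_parser import num_cuda_devices


    def _train(rank: int) -> None:
        # The forked child initializes CUDA itself; this only works because the
        # parent queried the device count via NVML instead of torch.cuda.
        torch.cuda.set_device(rank)
        print(f"worker {rank} is using {torch.cuda.get_device_name(rank)}")


    if __name__ == "__main__":
        world_size = num_cuda_devices()  # safe: no CUDA context is created in the parent
        if world_size > 0:
            mp.start_processes(_train, nprocs=world_size, start_method="fork")
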
@@ -302,7 +308,60 @@ def is_cuda_available() -> bool:
Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
return torch.cuda.is_available()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.is_available)
return num_cuda_devices() > 0


def _parse_visible_devices() -> Set[int]:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
var = os.getenv("CUDA_VISIBLE_DEVICES")
if var is None:
return {x for x in range(64)}

def _strtoul(s: str) -> int:
"""Return -1 or integer sequence string starts with."""
if len(s) == 0:
return -1
for idx, c in enumerate(s):
if not c.isdigit():
break
if idx + 1 == len(s):
idx += 1
return int(s[:idx]) if idx > 0 else -1

# CUDA_VISIBLE_DEVICES uses something like strtoul
# which makes `1gpu2,2ampere` equivalent to `1,2`
rc: Set[int] = set()
for elem in var.split(","):
rc.add(_strtoul(elem.strip()))
return rc


def _raw_device_count_nvml() -> int:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
from ctypes import c_int, CDLL

nvml_h = CDLL("libnvidia-ml.so.1")
rc = nvml_h.nvmlInit()
if rc != 0:
warnings.warn("Can't initialize NVML")
return -1
dev_arr = (c_int * 1)(-1)
rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
if rc != 0:
warnings.warn("Can't get nvml device count")
return -1
del nvml_h
return dev_arr[0]


def _device_count_nvml() -> int:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
try:
raw_cnt = _raw_device_count_nvml()
if raw_cnt <= 0:
return raw_cnt
return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
except OSError:
return -1
except AttributeError:
return -1
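For context on the helpers above: `_parse_visible_devices` mimics the strtoul-style handling of `CUDA_VISIBLE_DEVICES`, where each entry contributes only its leading digits and a non-numeric entry becomes -1, while `_device_count_nvml` returns -1 whenever NVML cannot be loaded or initialized so that `num_cuda_devices` can fall back to `torch.cuda.device_count()`. A small illustrative snippet (the private helpers are imported here only for demonstration, and the NVML call assumes `libnvidia-ml.so.1` is installed):

    import os

    from lightning_lite.utilities.device_parser import _device_count_nvml, _parse_visible_devices

    os.environ["CUDA_VISIBLE_DEVICES"] = "1gpu2,2ampere"
    print(_parse_visible_devices())  # {1, 2}: only the leading digits of each entry count

    os.environ["CUDA_VISIBLE_DEVICES"] = "foo,3"
    print(_parse_visible_devices())  # {-1, 3}: a non-numeric entry maps to -1

    # Returns -1 if NVML is unavailable, in which case num_cuda_devices()
    # falls back to torch.cuda.device_count() (which may create a CUDA context).
    print(_device_count_nvml())
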
@@ -68,10 +68,6 @@ def __init__(
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
f" {', '.join(mp.get_all_start_methods())}"
)
if start_method in ("fork", "forkserver") and _is_forking_disabled():
raise ValueError(
"Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method."
)

@property
def is_interactive_compatible(self) -> bool:
@@ -287,8 +283,3 @@ def restore(self) -> None:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)


def _is_forking_disabled() -> bool:
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
@@ -74,7 +74,6 @@
TPUSpawnStrategy,
)
from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
@@ -638,10 +637,6 @@ def _check_strategy_and_fallback(self) -> None:
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
f" platform. We recommed `Trainer(strategy='ddp_spawn')` instead."
)
if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled():
raise ValueError(
"Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different strategy."
)
if strategy_flag:
self._strategy_flag = strategy_flag

@@ -11,15 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch

from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from tests_pytorch.helpers.runif import RunIf


@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -28,14 +26,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_multiprocessing_launcher_disabled_forking(start_method):
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
_MultiProcessingLauncher(strategy=Mock(), start_method=start_method)


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
@@ -821,12 +821,3 @@ def test_accelerator_specific_checkpoint_io(*_):
def test_ddp_fork_on_unsupported_platform(_, strategy):
with pytest.raises(ValueError, match="process forking is not supported on this platform"):
Trainer(strategy=strategy)


@RunIf(skip_windows=True)
@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
"""Test there is an error when forking is disabled via the environment variable and the user requests fork."""
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
Trainer(devices=2, strategy=strategy)