Commit bf2d87f

awaelchli authored and carmocca committed
Attempt to query device count via NVML (#14631)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 40e9e99 commit bf2d87f

File tree

13 files changed: +104 −96 lines changed

src/lightning_lite/accelerators/cuda.py

Lines changed: 70 additions & 11 deletions
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import multiprocessing
-from typing import Dict, List, Optional, Union
+import os
+import warnings
+from functools import lru_cache
+from typing import Dict, List, Optional, Set, Union
 
 import torch
 
 from lightning_lite.accelerators.accelerator import Accelerator
-from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
 
 
 class CUDAAccelerator(Accelerator):
@@ -75,16 +77,20 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))
 
 
+@lru_cache(1)
 def num_cuda_devices() -> int:
-    """Returns the number of GPUs available.
+    """Returns the number of available CUDA devices.
 
     Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
     if the platform allows it.
     """
-    if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
+    if _TORCH_GREATER_EQUAL_1_13:
         return torch.cuda.device_count()
-    with multiprocessing.get_context("fork").Pool(1) as pool:
-        return pool.apply(torch.cuda.device_count)
+
+    # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
+    # TODO: Remove once minimum supported PyTorch version is 1.13
+    nvml_count = _device_count_nvml()
+    return torch.cuda.device_count() if nvml_count < 0 else nvml_count
 
 
 def is_cuda_available() -> bool:
@@ -93,7 +99,60 @@ def is_cuda_available() -> bool:
 
     Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
     if the platform allows it.
     """
-    if "fork" not in torch.multiprocessing.get_all_start_methods() or _is_forking_disabled():
-        return torch.cuda.is_available()
-    with multiprocessing.get_context("fork").Pool(1) as pool:
-        return pool.apply(torch.cuda.is_available)
+    return num_cuda_devices() > 0
+
+
+def _parse_visible_devices() -> Set[int]:
+    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+    var = os.getenv("CUDA_VISIBLE_DEVICES")
+    if var is None:
+        return {x for x in range(64)}
+
+    def _strtoul(s: str) -> int:
+        """Return -1 or integer sequence string starts with."""
+        if len(s) == 0:
+            return -1
+        for idx, c in enumerate(s):
+            if not c.isdigit():
+                break
+            if idx + 1 == len(s):
+                idx += 1
+        return int(s[:idx]) if idx > 0 else -1
+
+    # CUDA_VISIBLE_DEVICES uses something like strtoul
+    # which makes `1gpu2,2ampere` is equivalent to `1,2`
+    rc: Set[int] = set()
+    for elem in var.split(","):
+        rc.add(_strtoul(elem.strip()))
+    return rc
+
+
+def _raw_device_count_nvml() -> int:
+    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+    from ctypes import c_int, CDLL
+
+    nvml_h = CDLL("libnvidia-ml.so.1")
+    rc = nvml_h.nvmlInit()
+    if rc != 0:
+        warnings.warn("Can't initialize NVML")
+        return -1
+    dev_arr = (c_int * 1)(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
+    if rc != 0:
+        warnings.warn("Can't get nvml device count")
+        return -1
+    del nvml_h
+    return dev_arr[0]
+
+
+def _device_count_nvml() -> int:
+    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+    try:
+        raw_cnt = _raw_device_count_nvml()
+        if raw_cnt <= 0:
+            return raw_cnt
+        return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
+    except OSError:
+        return -1
+    except AttributeError:
+        return -1
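
For orientation, a minimal usage sketch of the helpers above (not part of the commit; it assumes a Linux host where the driver's `libnvidia-ml.so.1` is loadable):

```python
# Sketch: count CUDA devices without creating a CUDA context in this process,
# so that worker processes can still be launched with the "fork" start method.
from lightning_lite.accelerators.cuda import is_cuda_available, num_cuda_devices

if is_cuda_available():  # simply checks num_cuda_devices() > 0
    print(f"{num_cuda_devices()} CUDA device(s) available")
```

Because `num_cuda_devices` is wrapped in `@lru_cache(1)`, the count is computed once per process. If NVML cannot be loaded, `_device_count_nvml()` returns -1 and the helpers fall back to `torch.cuda.device_count()`; the NVML path also honors `CUDA_VISIBLE_DEVICES`, parsed strtoul-style (e.g. `1gpu2,2ampere` masks to devices 1 and 2).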

src/lightning_lite/strategies/launchers/multiprocessing.py

Lines changed: 0 additions & 9 deletions
@@ -63,10 +63,6 @@ def __init__(
                 f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
                 f" {', '.join(mp.get_all_start_methods())}"
             )
-        if start_method in ("fork", "forkserver") and _is_forking_disabled():
-            raise ValueError(
-                "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method."
-            )
 
     @property
     def is_interactive_compatible(self) -> bool:
@@ -170,8 +166,3 @@ def restore(self) -> None:
         torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
         torch.backends.cudnn.benchmark = self.cudnn_benchmark
         _set_rng_states(self.rng_states)
-
-
-def _is_forking_disabled() -> bool:
-    """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
-    return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))

src/lightning_lite/utilities/device_parser.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from typing import Any, List, MutableSequence, Optional, Tuple, Union
 
 from lightning_lite.accelerators.cuda import _get_all_available_cuda_gpus

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -81,8 +81,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))
 
 
-- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
+- Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631))
+
 
+- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
 
 
 ### Deprecated
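
Illustrative usage implied by the new entry above (a sketch, not part of the diff; assumes a machine with at least two CUDA devices): selecting a fork-based strategy no longer requires the `PL_DISABLE_FORK` escape hatch, because the device query no longer initializes CUDA in the parent process.

```python
# Sketch: the device count is obtained via NVML (or a fork-safe fallback), so
# the parent process holds no CUDA context and fork-based launching stays safe.
from pytorch_lightning import Trainer

trainer = Trainer(accelerator="cuda", devices=2, strategy="ddp_fork")
```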

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 0 additions & 9 deletions
@@ -68,10 +68,6 @@ def __init__(
                 f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
                 f" {', '.join(mp.get_all_start_methods())}"
             )
-        if start_method in ("fork", "forkserver") and _is_forking_disabled():
-            raise ValueError(
-                "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method."
-            )
 
     @property
     def is_interactive_compatible(self) -> bool:
@@ -287,8 +283,3 @@ def restore(self) -> None:
         torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
         torch.backends.cudnn.benchmark = self.cudnn_benchmark
         _set_rng_states(self.rng_states)
-
-
-def _is_forking_disabled() -> bool:
-    """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
-    return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 0 additions & 5 deletions
@@ -75,7 +75,6 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
-from pytorch_lightning.strategies.launchers.multiprocessing import _is_forking_disabled
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import (
@@ -632,10 +631,6 @@ def _check_strategy_and_fallback(self) -> None:
                 f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
                 f" platform. We recommed `Trainer(strategy='ddp_spawn')` instead."
             )
-        if strategy_flag in _DDP_FORK_ALIASES and _is_forking_disabled():
-            raise ValueError(
-                "Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different strategy."
-            )
         if strategy_flag:
             self._strategy_flag = strategy_flag

tests/tests_lite/accelerators/test_cpu.py

Lines changed: 8 additions & 1 deletion
@@ -15,7 +15,7 @@
 import pytest
 import torch
 
-from lightning_lite.accelerators.cpu import CPUAccelerator
+from lightning_lite.accelerators.cpu import CPUAccelerator, parse_cpu_cores
 
 
 def test_auto_device_count():
@@ -41,3 +41,10 @@ def test_init_device_with_wrong_device_type():
 )
 def test_get_parallel_devices(devices, expected):
     assert CPUAccelerator.get_parallel_devices(devices) == expected
+
+
+@pytest.mark.parametrize("devices", ([3], -1))
+def test_invalid_devices_with_cpu_accelerator(devices):
+    """Test invalid device flag raises MisconfigurationException."""
+    with pytest.raises(TypeError, match="should be an int > 0"):
+        parse_cpu_cores(devices)

tests/tests_lite/accelerators/test_cuda.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 import torch
 from tests_lite.helpers.runif import RunIf
 
-from lightning_lite.accelerators.cuda import CUDAAccelerator
+from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices
 
 
 @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -51,3 +51,12 @@ def test_get_parallel_devices(devices, expected):
 def test_set_cuda_device(set_device_mock):
     CUDAAccelerator().setup_device(torch.device("cuda", 1))
     set_device_mock.assert_called_once_with(torch.device("cuda", 1))
+
+
+@mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1)
+@mock.patch("torch.cuda.device_count", return_value=100)
+def test_num_cuda_devices_without_nvml(*_):
+    """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for
+    determining CUDA availability."""
+    assert is_cuda_available()
+    assert num_cuda_devices() == 100

tests/tests_lite/strategies/launchers/test_multiprocessing.py

Lines changed: 0 additions & 9 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from unittest import mock
 from unittest.mock import ANY, Mock
 
@@ -35,14 +34,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
     _MultiProcessingLauncher(strategy=Mock(), start_method="fork")
 
 
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
-@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
-def test_multiprocessing_launcher_disabled_forking(start_method):
-    with pytest.raises(ValueError, match="Forking is disabled in this environment"):
-        _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
-
-
 @pytest.mark.parametrize("start_method", ["spawn", "fork"])
 @mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
 def test_multiprocessing_launcher_start_method(mp_mock, start_method):

tests/tests_lite/test_connector.py

Lines changed: 0 additions & 9 deletions
@@ -692,12 +692,3 @@ def test_gpu_accelerator_no_gpu_backend_found_error(*_):
 def test_ddp_fork_on_unsupported_platform(_, strategy):
     with pytest.raises(ValueError, match="process forking is not supported on this platform"):
         _Connector(strategy=strategy)
-
-
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
-@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
-def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
-    """Test there is an error when forking is disabled via the environment variable and the user requests fork."""
-    with pytest.raises(ValueError, match="Forking is disabled in this environment"):
-        _Connector(devices=2, strategy=strategy)

tests/tests_lite/utilities/test_device_parser.py

Lines changed: 0 additions & 22 deletions
@@ -14,10 +14,7 @@
 from unittest import mock
 
 import pytest
-import torch
 
-from lightning_lite.accelerators.cpu import parse_cpu_cores
-from lightning_lite.accelerators.cuda import is_cuda_available, num_cuda_devices
 from lightning_lite.utilities import device_parser
 from lightning_lite.utilities.exceptions import MisconfigurationException
 
@@ -87,22 +84,3 @@ def test_parse_gpu_fail_on_non_existent_id_2(_):
 def test_parse_gpu_returns_none_when_no_devices_are_available(_, devices):
     with pytest.raises(MisconfigurationException):
         device_parser.parse_gpu_ids(devices, include_cuda=True)
-
-
-@pytest.mark.skipif(
-    "fork" in torch.multiprocessing.get_all_start_methods(), reason="Requires platform without forking support"
-)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@mock.patch("torch.cuda.device_count", return_value=2)
-def test_num_cuda_devices_without_forking(*_):
-    """This merely tests that on platforms without fork support our helper functions fall back to the default
-    implementation for determining cuda availability."""
-    assert is_cuda_available()
-    assert num_cuda_devices() == 2
-
-
-@pytest.mark.parametrize("devices", ([3], -1))
-def test_invalid_devices_with_cpu_accelerator(devices):
-    """Test invalid device flag raises MisconfigurationException."""
-    with pytest.raises(TypeError, match="should be an int > 0"):
-        parse_cpu_cores(devices)

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 0 additions & 10 deletions
@@ -11,15 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from unittest import mock
 from unittest.mock import ANY, Mock
 
 import pytest
 import torch
 
 from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
-from tests_pytorch.helpers.runif import RunIf
 
 
 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -28,14 +26,6 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
     _MultiProcessingLauncher(strategy=Mock(), start_method="fork")
 
 
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
-@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
-def test_multiprocessing_launcher_disabled_forking(start_method):
-    with pytest.raises(ValueError, match="Forking is disabled in this environment"):
-        _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
-
-
 @pytest.mark.parametrize("start_method", ["spawn", "fork"])
 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
 def test_multiprocessing_launcher_start_method(mp_mock, start_method):

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 0 additions & 9 deletions
@@ -794,12 +794,3 @@ def test_accelerator_specific_checkpoint_io(*_):
 def test_ddp_fork_on_unsupported_platform(_, strategy):
     with pytest.raises(ValueError, match="process forking is not supported on this platform"):
         Trainer(strategy=strategy)
-
-
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES)
-@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
-def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy):
-    """Test there is an error when forking is disabled via the environment variable and the user requests fork."""
-    with pytest.raises(ValueError, match="Forking is disabled in this environment"):
-        Trainer(devices=2, strategy=strategy)
