
Commit b00590c

[misc] tune some env vars for GB200 (vllm-project#16992)

youkaichao authored and wuisawesome committed
Signed-off-by: youkaichao <[email protected]>
1 parent 3746928 commit b00590c

File tree

1 file changed: +15 -2 lines changed


vllm/env_override.py

Lines changed: 15 additions & 2 deletions
@@ -8,8 +8,21 @@
 # that interact with vllm workers.
 # they are executed whenever `import vllm` is called.
 
-# see https://github.com/NVIDIA/nccl/issues/1234
-os.environ['NCCL_CUMEM_ENABLE'] = '0'
+if not os.path.exists('/dev/nvidia-caps-imex-channels'):
+    # normally, we disable NCCL_CUMEM_ENABLE because it
+    # will cost 1~2 GiB GPU memory with cudagraph+allreduce,
+    # see https://github.com/NVIDIA/nccl/issues/1234
+    # for more details.
+    # However, NCCL requires NCCL_CUMEM_ENABLE to work with
+    # multi-node NVLink, typically on GB200-NVL72 systems.
+    # The ultimate way to detect multi-node NVLink is to use
+    # NVML APIs, which are too expensive to call here.
+    # As an approximation, we check the existence of
+    # /dev/nvidia-caps-imex-channels, used by
+    # multi-node NVLink to communicate across nodes.
+    # This will still cost some GPU memory, but it is worthwhile
+    # because we can get very fast cross-node bandwidth with NVLink.
+    os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
 # see https://github.com/vllm-project/vllm/pull/15951
 # it avoids unintentional cuda initialization from torch.cuda.is_available()
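
The guarded override is easy to reproduce in isolation. Below is a minimal standalone sketch of the same logic, runnable without vllm; the final print statement is only illustrative and is not part of vllm/env_override.py.

    import os

    # Sketch of the override above: on machines without
    # /dev/nvidia-caps-imex-channels (i.e., no multi-node NVLink/IMEX),
    # disable NCCL's cuMem allocations to avoid the 1~2 GiB GPU memory
    # cost with cudagraph+allreduce; on GB200-NVL72-style systems the
    # variable is left alone so NCCL can use multi-node NVLink.
    if not os.path.exists('/dev/nvidia-caps-imex-channels'):
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

    # Illustrative only: show what the child NCCL processes would inherit.
    print(os.environ.get('NCCL_CUMEM_ENABLE', '<unset>'))

Because env_override.py runs on `import vllm`, this check happens once per process, before NCCL reads its environment.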
