
[misc] tune some env vars for GB200 #16992


Merged: 1 commit, Apr 23, 2025
17 changes: 15 additions & 2 deletions vllm/env_override.py
@@ -8,8 +8,21 @@
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.

-# see https://github.com/NVIDIA/nccl/issues/1234
-os.environ['NCCL_CUMEM_ENABLE'] = '0'
+if not os.path.exists('/dev/nvidia-caps-imex-channels'):
+    # normally, we disable NCCL_CUMEM_ENABLE because it
+    # will cost 1~2 GiB GPU memory with cudagraph+allreduce,
+    # see https://github.com/NVIDIA/nccl/issues/1234
+    # for more details.
+    # However, NCCL requires NCCL_CUMEM_ENABLE to work with
+    # multi-node NVLink, typically on GB200-NVL72 systems.
+    # The ultimate way to detect multi-node NVLink is to use
+    # NVML APIs, which are too expensive to call here.
+    # As an approximation, we check the existence of
+    # /dev/nvidia-caps-imex-channels, used by
+    # multi-node NVLink to communicate across nodes.
+    # This will still cost some GPU memory, but it is worthwhile
+    # because we can get very fast cross-node bandwidth with NVLink.
+    os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available()
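For readers who want to confirm which branch of this heuristic applies on a particular node, a minimal sketch (not part of the PR; it only re-checks the same device path after `import vllm` has run the override) could look like this:

```python
# Hypothetical check script, not part of this PR: it re-reads the device path
# that env_override.py checks and prints the resulting NCCL setting.
import os

import vllm  # noqa: F401  (importing vllm applies the env overrides)

has_imex = os.path.exists('/dev/nvidia-caps-imex-channels')
print(f"IMEX channel device present (multi-node NVLink likely): {has_imex}")
print(f"NCCL_CUMEM_ENABLE = {os.environ.get('NCCL_CUMEM_ENABLE', '<not set>')}")

# Per the diff above: without the IMEX device, vllm forces the variable to '0'
# to save the 1~2 GiB that cuMem-based buffers cost with cudagraph+allreduce;
# with the device present, it leaves the variable alone so NCCL can use
# multi-node NVLink (e.g. on GB200-NVL72 systems).
```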