Commit a02485f

youkaichao authored and rasmith committed

[misc] fix cross-node TP (vllm-project#12166)

Signed-off-by: youkaichao <[email protected]>

1 parent a7048b9 · commit a02485f

File tree

2 files changed: +36 -24 lines

vllm/executor/mp_distributed_executor.py

Lines changed: 36 additions & 2 deletions

@@ -1,4 +1,5 @@
 import asyncio
+import os
 from typing import Any, Callable, List, Optional, Union

 import cloudpickle
@@ -10,8 +11,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async, run_method)
+from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
+                        get_distributed_init_method, get_ip, get_open_port,
+                        make_async, run_method, update_environment_variables)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -22,7 +24,39 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):

     uses_ray: bool = False

+    def _check_cuda(self) -> None:
+        """Check that the number of GPUs is sufficient for the parallel
+        configuration. Separate from _init_executor to reduce the number of
+        indented blocks.
+        """
+        parallel_config = self.parallel_config
+        world_size = parallel_config.world_size
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Use confusing message for more common TP-only case.
+        if tensor_parallel_size > cuda_device_count:
+            raise RuntimeError(
+                f"please set tensor_parallel_size ({tensor_parallel_size}) "
+                f"to less than max local gpu count ({cuda_device_count})")
+
+        if world_size > cuda_device_count:
+            raise RuntimeError(
+                f"please ensure that world_size ({world_size}) "
+                f"is less than than max local gpu count ({cuda_device_count})")
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
     def _init_executor(self) -> None:
+
+        from vllm.platforms import current_platform
+        if current_platform.is_cuda_alike():
+            self._check_cuda()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
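For readers skimming the diff: the new _check_cuda guard only runs for the multiprocessing executor, where every worker is spawned on the local node. Below is a minimal standalone sketch of the same guard; check_local_gpus and local_gpu_count are hypothetical names standing in for the commit's _check_cuda and cuda_device_count_stateless(), not part of the change itself.

import os


def check_local_gpus(world_size: int, tensor_parallel_size: int,
                     local_gpu_count: int) -> None:
    # Mirror of the check above: with the multiprocessing executor, both the
    # TP group and the whole world must fit on locally visible GPUs.
    if tensor_parallel_size > local_gpu_count:
        raise RuntimeError(
            f"tensor_parallel_size ({tensor_parallel_size}) exceeds the "
            f"local GPU count ({local_gpu_count})")
    if world_size > local_gpu_count:
        raise RuntimeError(
            f"world_size ({world_size}) exceeds the local GPU count "
            f"({local_gpu_count})")
    # Expose the first world_size devices to the driver; spawned worker
    # processes inherit the environment, so they see the same device set.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(i) for i in range(world_size))


check_local_gpus(world_size=2, tensor_parallel_size=2, local_gpu_count=8)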

vllm/platforms/cuda.py

Lines changed: 0 additions & 22 deletions

@@ -139,28 +139,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             else:
                 parallel_config.worker_cls = "vllm.worker.worker.Worker"

-        world_size = parallel_config.world_size
-        tensor_parallel_size = parallel_config.tensor_parallel_size
-
-        from vllm.utils import (cuda_device_count_stateless,
-                                update_environment_variables)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
-        cuda_device_count = cuda_device_count_stateless()
-        # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
-
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
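The block removed here ran during platform-level config checks for every CUDA run, so it also fired for cross-node tensor parallelism (typically launched through the Ray-based executor), where the world size legitimately exceeds any single node's GPU count; moving the guard into the multiprocessing executor restricts it to the single-node case. A rough numeric illustration with assumed cluster sizes, not taken from the commit:

# Hypothetical cross-node TP setup: 2 nodes with 8 GPUs each, TP across all 16.
gpus_per_node = 8
num_nodes = 2
tensor_parallel_size = num_nodes * gpus_per_node  # 16
world_size = tensor_parallel_size                 # 16
local_gpu_count = gpus_per_node                   # only 8 visible on each node

# The old platform-level assert compared world_size against the local count,
# so this otherwise valid configuration was rejected before workers started.
print(world_size <= local_gpu_count)  # False -> the removed assert would fail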
