Commit 676a999

Authored by njhill and sahilsuneja1
[Core] Add MultiprocessingGPUExecutor (#4539)
Co-authored-by: SAHIL SUNEJA <[email protected]>
1 parent dc72402 commit 676a999

11 files changed: +225 -39 lines

.buildkite/test-pipeline.yaml

Lines changed: 8 additions & 4 deletions
@@ -34,10 +34,14 @@ steps:
   mirror_hardwares: [amd]
   commands:
   - pytest -v -s distributed/test_pynccl_library.py
-  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py
-  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
 
 - label: Distributed Tests (Multiple Groups)
   working_dir: "/vllm-workspace/tests"

tests/distributed/test_basic_distributed_correctness.py

Lines changed: 10 additions & 7 deletions
@@ -20,6 +20,7 @@
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
+DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
@@ -36,19 +37,21 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    enforce_eager = False
+    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
+
     backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    if backend_by_env_var == "FLASHINFER":
-        enforce_eager = True
+    enforce_eager = backend_by_env_var == "FLASHINFER"
 
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model,
-                             dtype=dtype,
-                             tensor_parallel_size=2,
-                             enforce_eager=enforce_eager)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        tensor_parallel_size=2,
+        enforce_eager=enforce_eager,
+        distributed_executor_backend=distributed_executor_backend)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
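
A note on the fixture used above: vllm_runner is defined in tests/conftest.py and is not part of this diff. The sketch below is a hypothetical, simplified stand-in (the name SimpleVllmRunner is made up) showing where the new distributed_executor_backend keyword ends up, assuming that vllm.LLM forwards extra keyword arguments on to EngineArgs.

# Hypothetical, simplified stand-in for the vllm_runner fixture
# (tests/conftest.py, not shown in this diff). Illustration only.
import os
from typing import List, Optional

from vllm import LLM, SamplingParams


class SimpleVllmRunner:

    def __init__(self,
                 model: str,
                 dtype: str = "half",
                 tensor_parallel_size: int = 1,
                 enforce_eager: bool = False,
                 distributed_executor_backend: Optional[str] = None) -> None:
        # LLM passes extra keyword arguments through to EngineArgs, so the new
        # distributed_executor_backend field can be forwarded directly.
        self.llm = LLM(model=model,
                       dtype=dtype,
                       tensor_parallel_size=tensor_parallel_size,
                       enforce_eager=enforce_eager,
                       distributed_executor_backend=distributed_executor_backend)

    def generate_greedy(self, prompts: List[str], max_tokens: int) -> List[str]:
        params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        return [out.outputs[0].text for out in self.llm.generate(prompts, params)]


# As in the test: "ray", "mp", or None (None lets vLLM choose automatically).
backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND")
runner = SimpleVllmRunner("facebook/opt-125m",
                          tensor_parallel_size=2,
                          distributed_executor_backend=backend)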

tests/distributed/test_chunked_prefill_distributed.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
+DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -36,6 +37,8 @@ def test_models(
     max_tokens: int,
     chunked_prefill_token_size: int,
 ) -> None:
+    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
+
     # Add a chunked prefill config.
     max_num_seqs = min(chunked_prefill_token_size, 256)
     assert chunked_prefill_token_size != -1
@@ -53,6 +56,7 @@ def test_models(
         max_num_seqs=max_num_seqs,
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
     )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model

tests/lora/test_mixtral.py

Lines changed: 1 addition & 2 deletions
@@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
                    enable_lora=True,
                    max_num_seqs=16,
                    max_loras=4,
-                   tensor_parallel_size=tp_size,
-                   worker_use_ray=True)
+                   tensor_parallel_size=tp_size)
 
     expected_lora_output = [
         "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",  # noqa: E501

vllm/config.py

Lines changed: 29 additions & 9 deletions
@@ -521,9 +521,7 @@ class ParallelConfig:
     Args:
         pipeline_parallel_size: Number of pipeline parallel groups.
         tensor_parallel_size: Number of tensor parallel groups.
-        worker_use_ray: Whether to use Ray for model workers. Will be set to
-            True if either pipeline_parallel_size or tensor_parallel_size is
-            greater than 1.
+        worker_use_ray: Deprecated, use distributed_executor_backend instead.
         max_parallel_loading_workers: Maximum number of multiple batches
             when load model sequentially. To avoid RAM OOM when using tensor
             parallel and large models.
@@ -533,37 +531,57 @@ class ParallelConfig:
             If None, will use synchronous tokenization.
         ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
             https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
+        distributed_executor_backend: Backend to use for distributed model
+            workers, either "ray" or "mp" (multiprocessing). If either
+            pipeline_parallel_size or tensor_parallel_size is greater than 1,
+            will default to "ray" if Ray is installed or "mp" otherwise.
     """
 
     def __init__(
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        worker_use_ray: bool,
+        worker_use_ray: Optional[bool] = None,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
         ray_workers_use_nsight: bool = False,
         placement_group: Optional["PlacementGroup"] = None,
+        distributed_executor_backend: Optional[str] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.worker_use_ray = worker_use_ray
+        self.distributed_executor_backend = distributed_executor_backend
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.tokenizer_pool_config = tokenizer_pool_config
         self.ray_workers_use_nsight = ray_workers_use_nsight
         self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-        if self.world_size > 1:
-            self.worker_use_ray = True
+        if worker_use_ray:
+            if self.distributed_executor_backend is None:
+                self.distributed_executor_backend = "ray"
+            elif self.distributed_executor_backend != "ray":
+                raise ValueError(f"worker-use-ray can't be used with "
+                                 f"distributed executor backend "
+                                 f"'{self.distributed_executor_backend}'.")
+
+        if self.distributed_executor_backend is None and self.world_size > 1:
+            from vllm.executor import ray_utils
+            ray_found = ray_utils.ray is not None
+            self.distributed_executor_backend = "ray" if ray_found else "mp"
+
         self._verify_args()
 
     def _verify_args(self) -> None:
         if self.pipeline_parallel_size > 1:
             raise NotImplementedError(
                 "Pipeline parallelism is not supported yet.")
+        if self.distributed_executor_backend not in ("ray", "mp", None):
+            raise ValueError(
+                "Unrecognized distributed executor backend. Supported values "
+                "are 'ray' or 'mp'.")
         if not self.disable_custom_all_reduce and self.world_size > 1:
             if is_hip():
                 self.disable_custom_all_reduce = True
@@ -575,7 +593,8 @@ def _verify_args(self) -> None:
                 logger.info(
                     "Disabled the custom all-reduce kernel because it is not "
                     "supported with pipeline parallelism.")
-        if self.ray_workers_use_nsight and not self.worker_use_ray:
+        if self.ray_workers_use_nsight and (
+                not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
@@ -887,7 +906,8 @@ def create_draft_parallel_config(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
             tensor_parallel_size=target_parallel_config.tensor_parallel_size,
-            worker_use_ray=target_parallel_config.worker_use_ray,
+            distributed_executor_backend=target_parallel_config.
+            distributed_executor_backend,
             max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
             disable_custom_all_reduce=target_parallel_config.
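
A minimal sketch of the resolution rules added to ParallelConfig above. Constructing ParallelConfig directly is unusual and done here only for illustration; in normal use it is built from EngineArgs.

# How ParallelConfig resolves the backend after this change.
from vllm.config import ParallelConfig

# No backend given and world_size > 1: defaults to "ray" when Ray is
# importable, otherwise to the new "mp" (multiprocessing) backend.
cfg = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=2)
print(cfg.distributed_executor_backend)  # "ray" or "mp"

# The deprecated worker_use_ray flag is mapped onto the new field...
assert ParallelConfig(1, 2, worker_use_ray=True).distributed_executor_backend == "ray"

# ...and a conflict with an explicitly requested non-Ray backend is rejected.
try:
    ParallelConfig(1, 2, worker_use_ray=True, distributed_executor_backend="mp")
except ValueError as err:
    print(err)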

vllm/engine/arg_utils.py

Lines changed: 12 additions & 4 deletions
@@ -34,6 +34,7 @@ class EngineArgs:
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
+    distributed_executor_backend: Optional[str] = None
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
@@ -221,10 +222,17 @@ def add_cli_args(
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
         # Parallel arguments
-        parser.add_argument('--worker-use-ray',
-                            action='store_true',
-                            help='Use Ray for distributed serving, will be '
-                            'automatically set when using more than 1 GPU.')
+        parser.add_argument(
+            '--distributed-executor-backend',
+            choices=['ray', 'mp'],
+            default=EngineArgs.distributed_executor_backend,
+            help='Backend to use for distributed serving. When more than 1 GPU '
+            'is used, will be automatically set to "ray" if installed '
+            'or "mp" (multiprocessing) otherwise.')
+        parser.add_argument(
+            '--worker-use-ray',
+            action='store_true',
+            help='Deprecated, use --distributed-executor-backend=ray.')
         parser.add_argument('--pipeline-parallel-size',
                             '-pp',
                             type=int,
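
A short sketch of the CLI path for the new flag, assuming the usual EngineArgs.add_cli_args / EngineArgs.from_cli_args helpers used by the vLLM servers are otherwise unchanged:

# Parsing the new --distributed-executor-backend flag into EngineArgs.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--tensor-parallel-size", "2",
    "--distributed-executor-backend", "mp",
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.distributed_executor_backend == "mp"

# The deprecated flag still parses; ParallelConfig later maps it onto "ray".
legacy = parser.parse_args(["--model", "facebook/opt-125m", "--worker-use-ray"])
assert EngineArgs.from_cli_args(legacy).worker_use_ray is True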

vllm/engine/async_llm_engine.py

Lines changed: 10 additions & 6 deletions
@@ -348,27 +348,31 @@ def from_engine_args(
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
+        distributed_executor_backend = (
+            engine_config.parallel_config.distributed_executor_backend)
 
         if engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "cpu":
-            assert not engine_config.parallel_config.worker_use_ray, (
-                "Ray is not supported with the CPU backend.")
+            assert distributed_executor_backend is None, (
+                "Distributed execution is not supported with the CPU backend.")
             from vllm.executor.cpu_executor import CPUExecutorAsync
             executor_class = CPUExecutorAsync
-        elif engine_config.parallel_config.worker_use_ray:
+        elif distributed_executor_backend == "ray":
             initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
+        elif distributed_executor_backend == "mp":
+            from vllm.executor.multiproc_gpu_executor import (
+                MultiprocessingGPUExecutorAsync)
+            executor_class = MultiprocessingGPUExecutorAsync
         else:
-            assert engine_config.parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1.")
             from vllm.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
         # Create the async LLM engine.
         engine = cls(
-            engine_config.parallel_config.worker_use_ray,
+            distributed_executor_backend == "ray",
             engine_args.engine_use_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
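
For the async path, the same selection can be exercised from AsyncEngineArgs; a hedged sketch (requires at least two GPUs to actually start workers; the executor import path is the one added in the branch above):

# Building an async engine on the new multiprocessing backend.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              tensor_parallel_size=2,
                              distributed_executor_backend="mp")
engine = AsyncLLMEngine.from_engine_args(engine_args)
# from_engine_args now resolves executor_class to
# vllm.executor.multiproc_gpu_executor.MultiprocessingGPUExecutorAsync
# instead of requiring Ray.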

vllm/engine/llm_engine.py

Lines changed: 7 additions & 3 deletions
@@ -274,6 +274,8 @@ def from_engine_args(
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
+        distributed_executor_backend = (
+            engine_config.parallel_config.distributed_executor_backend)
 
         # Initialize the cluster and specify the executor class.
         if engine_config.device_config.device_type == "neuron":
@@ -282,13 +284,15 @@ def from_engine_args(
         elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
-        elif engine_config.parallel_config.worker_use_ray:
+        elif distributed_executor_backend == "ray":
             initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.executor.multiproc_gpu_executor import (
+                MultiprocessingGPUExecutor)
+            executor_class = MultiprocessingGPUExecutor
         else:
-            assert engine_config.parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1.")
             from vllm.executor.gpu_executor import GPUExecutor
             executor_class = GPUExecutor
 
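
Putting it together, a hedged end-to-end usage sketch of the new backend from the offline entrypoint, assuming vllm.LLM forwards extra keyword arguments to EngineArgs (model name is only an example; needs at least two GPUs):

# Tensor-parallel inference without Ray via the new "mp" backend. Omitting
# distributed_executor_backend lets vLLM pick "ray" if installed, else "mp".
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)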
