Commit 9605c12

[V1][core] Implement pipeline parallel on Ray (#12996)
1 parent 0ccd876 commit 9605c12

File tree: 7 files changed (+110 -45 lines changed)

tests/distributed/test_pipeline_parallel.py

Lines changed: 39 additions & 12 deletions
@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
+    # NOTE: the length of distributed_backends and
+    # vllm_major_versions should be the same, and they
+    # are first zipped together to iterate over all
+    # test settings.
     distributed_backends: List[str]
+    # vllm major version: "0" for V0, "1" for V1
+    vllm_major_versions: List[str]
     task: TaskOption
     test_options: PPTestOptions

+    def __post_init__(self):
+        if len(self.distributed_backends) != len(self.vllm_major_versions):
+            raise ValueError(
+                f"Length mismatch: distributed_backends "
+                f"({len(self.distributed_backends)}) != "
+                f"vllm_major_versions ({len(self.vllm_major_versions)})")
+
     @staticmethod
     def detailed(
         *,
@@ -79,7 +92,9 @@ def detailed(
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            distributed_backends=["mp", "ray"],
+            # only ray is supported for V1
+            distributed_backends=["mp", "ray", "ray"],
+            vllm_major_versions=["0", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -108,6 +123,7 @@ def fast(
                               chunked_prefill=False),
             ],
             distributed_backends=["mp"],
+            vllm_major_versions=["0"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -120,8 +136,9 @@ def iter_params(self, model_name: str):
         opts = self.test_options

         for parallel_setup in self.parallel_setups:
-            for distributed_backend in self.distributed_backends:
-                yield (model_name, parallel_setup, distributed_backend,
+            for backend, vllm_major_version in zip(self.distributed_backends,
+                                                   self.vllm_major_versions):
+                yield (model_name, parallel_setup, backend, vllm_major_version,
                        self.task, opts)


@@ -244,6 +261,7 @@ def _compare_tp(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available: int,
@@ -296,10 +314,13 @@ def _compare_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])

-    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
-            and chunked_prefill):
-        # Test Ray ADAG for a subset of the tests
+    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
+    if distributed_backend == "ray" and (vllm_major_version == "1"
+                                         or specific_case):
+        # For V1, test Ray ADAG for all the tests
+        # For V0, test Ray ADAG for a subset of the tests
         pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
             "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -348,8 +369,8 @@ def _compare_tp(


 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in TEXT_GENERATION_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -361,22 +382,24 @@ def test_tp_language_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
                 method="generate")


 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -388,22 +411,24 @@ def test_tp_language_embedding(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
                 method="encode")


 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -415,13 +440,15 @@ def test_tp_multimodal_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
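
The new vllm_major_versions field is consumed by iter_params, which zips it with distributed_backends so each backend entry carries the vLLM major version it should run under. Below is a minimal standalone sketch of how the parametrization expands; PPTestSettingsSketch and the simplified fields are stand-ins introduced for illustration, not the real test classes.

from dataclasses import dataclass
from typing import List


@dataclass
class PPTestSettingsSketch:
    # Simplified stand-ins: only the fields needed to show the zip behavior.
    parallel_setups: List[str]
    distributed_backends: List[str]
    vllm_major_versions: List[str]

    def __post_init__(self):
        # Same invariant as the diff: the two lists must line up one-to-one.
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError("Length mismatch")

    def iter_params(self, model_name: str):
        for setup in self.parallel_setups:
            for backend, major in zip(self.distributed_backends,
                                      self.vllm_major_versions):
                yield (model_name, setup, backend, major)


settings = PPTestSettingsSketch(
    parallel_setups=["tp2_pp2"],
    distributed_backends=["mp", "ray", "ray"],
    vllm_major_versions=["0", "0", "1"],
)
# Expands to three cases: ("mp", "0"), ("ray", "0"), ("ray", "1"),
# i.e. the Ray backend is exercised on both V0 and V1, mp only on V0.
print(list(settings.iter_params("dummy-model")))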

vllm/executor/ray_utils.py

Lines changed: 9 additions & 2 deletions
@@ -35,7 +35,7 @@

 class RayWorkerWrapper(WorkerWrapperBase):
     """Ray wrapper for vllm.worker.Worker, allowing Worker to be
-    lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+    lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ def execute_model(
     ) -> "ModelRunnerOutput":
         self.setup_device_if_necessary()
         assert self.worker is not None, "Worker is not initialized"
-        output = self.worker.model_runner.execute_model(scheduler_output)
+        if isinstance(scheduler_output, tuple):
+            scheduler_output, intermediate_tensors = scheduler_output
+        else:
+            scheduler_output, intermediate_tensors = scheduler_output, None
+        output = self.worker.model_runner.execute_model(
+            scheduler_output, intermediate_tensors)
+        if isinstance(output, IntermediateTensors):
+            output = scheduler_output, output
         return output

     def override_env_vars(self, vars: Dict[str, str]):
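
With pipeline parallelism, a Ray DAG node may receive either a bare scheduler output (first PP stage) or a (scheduler_output, IntermediateTensors) tuple forwarded by the previous stage, and a non-final stage re-packs its intermediates with the scheduler output for the next stage. A minimal sketch of that pack/unpack convention follows; SchedulerOutput and IntermediateTensors here are placeholder classes standing in for the real vLLM types.

from typing import Any, Optional, Tuple, Union


class SchedulerOutput:  # placeholder for vLLM's scheduler output
    pass


class IntermediateTensors:  # placeholder for vLLM's IntermediateTensors
    pass


def unpack(
    inp: Union[SchedulerOutput, Tuple[SchedulerOutput, IntermediateTensors]]
) -> Tuple[SchedulerOutput, Optional[IntermediateTensors]]:
    # The first PP stage gets a bare SchedulerOutput; later stages get a
    # tuple that also carries the previous stage's hidden states.
    if isinstance(inp, tuple):
        return inp
    return inp, None


def pack(scheduler_output: SchedulerOutput, output: Any):
    # Non-final stages produce IntermediateTensors; re-pack them with the
    # scheduler output so the next DAG node sees the same tuple shape.
    if isinstance(output, IntermediateTensors):
        return scheduler_output, output
    return output  # final stage: the model output passes through unchanged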

vllm/v1/core/kv_cache_utils.py

Lines changed: 27 additions & 14 deletions
@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:

 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                       kv_cache_spec: KVCacheSpec,
-                                      available_memory: int) -> KVCacheConfig:
+                                      available_memory: int,
+                                      num_layers: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model with one type of KV cache.
     Divide the available memory equally among all layers.
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         vllm_config: The global VllmConfig
         kv_cache_spec: The kv cache spec of the model
         available_memory: Memory available for KV cache in bytes.
+        num_layers: The number of layers in the model.

     Returns:
         The generated KVCacheConfig
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     assert len(page_sizes) == 1
     page_size = page_sizes.pop()

-    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+    num_blocks = int(available_memory // page_size // num_layers)
     num_blocks = max(num_blocks, 0)

     if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     return kv_cache_config


-def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
-                        available_memory: int) -> KVCacheConfig:
+def get_kv_cache_configs(vllm_config: VllmConfig,
+                         kv_cache_specs: List[KVCacheSpec],
+                         available_memory: int) -> List[KVCacheConfig]:
     """
     Generates the KV cache configuration for a model
     TODO: support hybrid models with more than one type of KV cache.

     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_specs: The kv cache specs of the model
         available_memory: Memory available for KV cache in bytes.

     Returns:
-        The generated KVCacheConfig
+        The generated KVCacheConfigs
     """
-    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
-    if is_kv_cache_type_uniform(kv_cache_spec):
-        # KV cache of all layers are the same, which is true for most models.
-        # Allocate the same amount of memory for each layer.
-        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
-                                                 available_memory)
-    else:
-        raise NotImplementedError
+    # Use the max number of layers to conservatively determine
+    # the number of blocks.
+    num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
+    kv_cache_configs = []
+    for kv_cache_spec in kv_cache_specs:
+        check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
+                                     available_memory)
+        if is_kv_cache_type_uniform(kv_cache_spec):
+            # KV cache of all layers are the same, which is true for
+            # most models. Allocate the same amount of memory for
+            # each layer.
+            kv_cache_configs.append(
+                _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                                   available_memory,
                                                   num_layers))
+        else:
+            raise NotImplementedError
+    return kv_cache_configs
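
With pipeline parallelism the per-rank KV cache specs can contain different layer counts (layers are split across ranks), so the block count is sized with the maximum layer count to stay conservative and keep num_blocks identical on every rank. A small worked example of that arithmetic; all numbers are made up for illustration.

# Hypothetical numbers: 8 GiB free for KV cache, 2 MiB per block per layer,
# and two PP ranks holding 17 and 15 layers respectively.
available_memory = 8 * 1024**3
page_size = 2 * 1024**2
layers_per_rank = [17, 15]

# Sizing with the max layer count (17) yields fewer blocks than sizing each
# rank with its own count, but every rank can then allocate the same number.
num_layers = max(layers_per_rank)
num_blocks = available_memory // page_size // num_layers
print(num_blocks)  # 240 blocks on both ranks, instead of 240 and 273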

vllm/v1/engine/core.py

Lines changed: 12 additions & 7 deletions
@@ -16,7 +16,7 @@
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
+from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
@@ -73,20 +73,25 @@ def _initialize_kv_caches(self,
         start = time.time()

         # Get all kv cache needed by the model
-        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()

         # Profiles the peak memory usage of the model to determine how much
         # memory can be allocated for kv cache.
-        availble_gpu_memory = self.model_executor.determine_available_memory()
+        available_gpu_memory = self.model_executor.determine_available_memory()

         # Get the kv cache tensor size
-        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
-                                              availble_gpu_memory)
-        num_gpu_blocks = kv_cache_config.num_blocks
+        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
+                                                available_gpu_memory)
+        num_gpu_blocks_set = set(config.num_blocks
+                                 for config in kv_cache_configs)
+        assert len(num_gpu_blocks_set) == 1, (
+            f"num_gpu_blocks need to be the same across workers, "
+            f"but they are different: {num_gpu_blocks_set}")
+        num_gpu_blocks = num_gpu_blocks_set.pop()
         num_cpu_blocks = 0

         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_config)
+        self.model_executor.initialize(kv_cache_configs)

         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
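
The engine core now collects one KVCacheConfig per worker and asserts that every worker computed the same num_blocks, since the scheduler tracks a single block count for all pipeline stages. A hedged sketch of that check in isolation; KVCacheConfigSketch is a stand-in class, while the variable names mirror the diff.

from dataclasses import dataclass
from typing import List


@dataclass
class KVCacheConfigSketch:  # stand-in for vLLM's KVCacheConfig
    num_blocks: int


def resolve_num_gpu_blocks(kv_cache_configs: List[KVCacheConfigSketch]) -> int:
    # All pipeline stages must agree on the block count, otherwise the
    # scheduler's bookkeeping would not match some workers' actual caches.
    num_gpu_blocks_set = {config.num_blocks for config in kv_cache_configs}
    assert len(num_gpu_blocks_set) == 1, (
        f"num_gpu_blocks need to be the same across workers, "
        f"but they are different: {num_gpu_blocks_set}")
    return num_gpu_blocks_set.pop()


print(resolve_num_gpu_blocks([KVCacheConfigSketch(240),
                              KVCacheConfigSketch(240)]))  # 240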

vllm/v1/executor/abstract.py

Lines changed: 5 additions & 7 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Type
+from typing import List, Type

 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
                 f"{distributed_executor_backend}")
         return executor_class

-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")

     def determine_available_memory(self) -> int:  # in bytes
@@ -63,11 +63,9 @@ def determine_available_memory(self) -> int:  # in bytes
         # operators can be applied to all workers.
         return min(output)

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_specs(self) -> List[KVCacheSpec]:
         output = self.collective_rpc("get_kv_cache_spec")
-        for x in output:
-            assert x == output[0]
-        return output[0]
+        return output

     def execute_model(
         self,
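
get_kv_cache_specs no longer asserts that every worker returns an identical spec; with pipeline parallelism each rank owns a different slice of layers, so the executor keeps the per-worker list and later hands per-worker configs back through initialize. A rough sketch of the old versus new behavior with a generic RPC fan-out; the collective_rpc helper and the spec type below are simplified assumptions, not the real executor API.

from typing import Callable, Dict, List

# Stand-in spec type: layer name -> kind of attention for that layer.
KVCacheSpecSketch = Dict[str, str]


def collective_rpc(
        workers: List[Callable[[], KVCacheSpecSketch]]
) -> List[KVCacheSpecSketch]:
    # Simplified fan-out: call each "worker" and gather the results.
    return [worker() for worker in workers]


workers = [
    lambda: {"layers.0": "full_attention", "layers.1": "full_attention"},  # PP rank 0
    lambda: {"layers.2": "full_attention"},                                # PP rank 1
]

# Old behavior: assert all specs are equal and return output[0]; that would
# fail here because the two ranks hold different layers.
# New behavior: return the whole list, one spec per worker.
kv_cache_specs = collective_rpc(workers)
print(len(kv_cache_specs))  # 2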
