
Commit 777e20a

[V1][core] Implement pipeline parallel on Ray
Signed-off-by: Rui Qiao <[email protected]>
1 parent 29f1d47 commit 777e20a

File tree

6 files changed: +68, -31 lines

tests/distributed/test_pipeline_parallel.py

Lines changed: 20 additions & 13 deletions
@@ -8,7 +8,7 @@
 """
 import os
 from dataclasses import dataclass
-from typing import List, Literal, NamedTuple, Optional
+from typing import List, Literal, NamedTuple, Optional, Tuple
 
 import pytest
 
@@ -40,7 +40,8 @@ class PPTestOptions(NamedTuple):
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
-    distributed_backends: List[str]
+    # vllm major version: "0" for V0, "1" for V1
+    distributed_backends_and_major_versions: List[Tuple[str, str]]
     task: TaskOption
     test_options: PPTestOptions
 
@@ -79,7 +80,8 @@ def detailed(
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            distributed_backends=["mp", "ray"],
+            distributed_backends_and_major_versions=[("mp", "0"), ("ray", "0"),
+                                                     ("ray", "1")],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
@@ -107,7 +109,7 @@ def fast(
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            distributed_backends=["mp"],
+            distributed_backends_and_major_versions=[("mp", "0")],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
@@ -120,9 +122,9 @@ def iter_params(self, model_name: str):
         opts = self.test_options
 
         for parallel_setup in self.parallel_setups:
-            for distributed_backend in self.distributed_backends:
-                yield (model_name, parallel_setup, distributed_backend,
-                       self.task, opts)
+            for backend_and_ver in self.distributed_backends_and_major_versions:
+                yield (model_name, parallel_setup, backend_and_ver, self.task,
+                       opts)
 
 
 # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -296,9 +298,11 @@ def _compare_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])
 
-    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
-            and chunked_prefill):
-        # Test Ray ADAG for a subset of the tests
+    vllm_use_v1 = os.getenv("VLLM_USE_V1", "0") == "1"
+    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
+    if distributed_backend == "ray" and (vllm_use_v1 or specific_case):
+        # For V1, test Ray ADAG for all the tests
+        # For V0, test Ray ADAG for a subset of the tests
         pp_env = {
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
@@ -348,8 +352,8 @@ def _compare_tp(
 
 
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend_and_major_version",
+     "task", "test_options"),
     [
         params for model_name, settings in TEXT_GENERATION_MODELS.items()
        for params in settings.iter_params(model_name)
@@ -358,13 +362,16 @@ def _compare_tp(
 )
 @fork_new_process_for_each_test
 def test_tp_language_generation(
+    monkeypatch,
     model_name: str,
     parallel_setup: ParallelSetup,
-    distributed_backend: str,
+    distributed_backend_and_major_version: Tuple[str, str],
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
 ):
+    distributed_backend, major_version = distributed_backend_and_major_version
+    monkeypatch.setenv('VLLM_USE_V1', major_version)
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,

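Not part of the commit, but a minimal runnable sketch of the pattern the test change introduces: each parametrized case now carries a (distributed_backend, major_version) pair, and the test selects the V0 or V1 engine by setting VLLM_USE_V1 before launching. The case list and test name below are illustrative stand-ins, not the real PPTestSettings machinery.

import os

import pytest

# Illustrative cases only; the commit generates these via PPTestSettings.
CASES = [("mp", "0"), ("ray", "0"), ("ray", "1")]


@pytest.mark.parametrize("distributed_backend_and_major_version", CASES)
def test_backend_and_version(monkeypatch,
                             distributed_backend_and_major_version):
    distributed_backend, major_version = distributed_backend_and_major_version
    # Same switch the commit uses: pick the V0 or V1 engine per test case.
    monkeypatch.setenv("VLLM_USE_V1", major_version)
    assert os.environ["VLLM_USE_V1"] == major_version
    assert distributed_backend in ("mp", "ray")
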
vllm/executor/ray_utils.py

Lines changed: 9 additions & 2 deletions
@@ -35,7 +35,7 @@
 
 class RayWorkerWrapper(WorkerWrapperBase):
     """Ray wrapper for vllm.worker.Worker, allowing Worker to be
-    lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+    lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ def execute_model(
     ) -> "ModelRunnerOutput":
         self.setup_device_if_necessary()
         assert self.worker is not None, "Worker is not initialized"
-        output = self.worker.model_runner.execute_model(scheduler_output)
+        if isinstance(scheduler_output, tuple):
+            scheduler_output, intermediate_tensors = scheduler_output
+        else:
+            scheduler_output, intermediate_tensors = scheduler_output, None
+        output = self.worker.model_runner.execute_model(
+            scheduler_output, intermediate_tensors)
+        if isinstance(output, IntermediateTensors):
+            output = scheduler_output, output
         return output
 
     def override_env_vars(self, vars: Dict[str, str]):

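Reviewer sketch (not vLLM code; every class below is a stand-in) of the tuple convention this wrapper change relies on for pipeline parallelism on Ray: a non-last rank's model runner returns IntermediateTensors, the wrapper re-packs them together with the scheduler output, and the next rank's wrapper unpacks that tuple before invoking its own runner.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerOutput:          # stand-in for the V1 SchedulerOutput
    num_scheduled_tokens: int


@dataclass
class IntermediateTensors:      # stand-in for vllm.sequence.IntermediateTensors
    tensors: dict


class FakeModelRunner:
    def __init__(self, is_last_rank: bool):
        self.is_last_rank = is_last_rank

    def execute_model(self, scheduler_output: SchedulerOutput,
                      intermediate_tensors: Optional[IntermediateTensors]):
        if self.is_last_rank:
            return {"sampled": [0] * scheduler_output.num_scheduled_tokens}
        # Non-last ranks stop before sampling and hand activations onward.
        return IntermediateTensors({"hidden_states": "..."})


class WorkerWrapper:
    def __init__(self, is_last_rank: bool):
        self.runner = FakeModelRunner(is_last_rank)

    def execute_model(self, scheduler_output):
        # Mirrors the commit: unpack the (scheduler_output, tensors) tuple if
        # the previous rank produced one, otherwise pass None.
        if isinstance(scheduler_output, tuple):
            scheduler_output, intermediate_tensors = scheduler_output
        else:
            intermediate_tensors = None
        output = self.runner.execute_model(scheduler_output,
                                           intermediate_tensors)
        if isinstance(output, IntermediateTensors):
            # Re-pack so the next rank receives both pieces as one object
            # over the compiled-DAG channel.
            output = scheduler_output, output
        return output


rank0 = WorkerWrapper(is_last_rank=False)
rank1 = WorkerWrapper(is_last_rank=True)
print(rank1.execute_model(rank0.execute_model(SchedulerOutput(4))))
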
vllm/v1/engine/core.py

Lines changed: 16 additions & 6 deletions
@@ -71,20 +71,30 @@ def _initialize_kv_caches(self,
         start = time.time()
 
         # Get all kv cache needed by the model
-        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
         # Profiles the peak memory usage of the model to determine how much
         # memory can be allocated for kv cache.
-        availble_gpu_memory = self.model_executor.determine_available_memory()
+        available_gpu_memory = self.model_executor.determine_available_memory()
 
         # Get the kv cache tensor size
-        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
-                                              availble_gpu_memory)
-        num_gpu_blocks = kv_cache_config.num_blocks
+        kv_cache_configs = []
+        num_gpu_blocks = None
+        for kv_cache_spec in kv_cache_specs:
+            kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                                  available_gpu_memory)
+            kv_cache_configs.append(kv_cache_config)
+            if num_gpu_blocks is None:
+                num_gpu_blocks = kv_cache_config.num_blocks
+            elif num_gpu_blocks != kv_cache_config.num_blocks:
+                raise NotImplementedError(
+                    "num_gpu_blocks need to be the same across workers: "
+                    f"{num_gpu_blocks} != {kv_cache_config.num_blocks}")
+        assert num_gpu_blocks is not None
         num_cpu_blocks = 0
 
         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_config)
+        self.model_executor.initialize(kv_cache_configs)
 
         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "

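A toy calculation, under assumed sizes and a hypothetical num_blocks helper, of why the engine now collects one KV cache spec per worker yet still enforces a single block count: with pipeline parallelism each rank hosts a different slice of layers, so the per-worker configs can differ, while the scheduler tracks only one num_gpu_blocks.

def num_blocks(available_gpu_memory: int, num_layers: int,
               block_bytes_per_layer: int) -> int:
    # Hypothetical sizing: how many KV cache blocks fit in the memory budget.
    return available_gpu_memory // (num_layers * block_bytes_per_layer)


available_gpu_memory = 8 * 1024**3             # 8 GiB left for KV cache per rank
block_bytes_per_layer = 2 * 16 * 8 * 128 * 2   # K+V, 16 tokens, 8 heads, dim 128, fp16

layers_per_rank = [16, 16]                     # pp_size=2 over a 32-layer model
blocks = [num_blocks(available_gpu_memory, n, block_bytes_per_layer)
          for n in layers_per_rank]

# Same consistency rule as the loop added to _initialize_kv_caches.
assert len(set(blocks)) == 1, f"num_gpu_blocks differ across workers: {blocks}"
print(blocks[0])
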
vllm/v1/executor/abstract.py

Lines changed: 5 additions & 7 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Type
+from typing import List, Type
 
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
                 f"{distributed_executor_backend}")
         return executor_class
 
-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")
 
     def determine_available_memory(self) -> int:  # in bytes
@@ -63,11 +63,9 @@ def determine_available_memory(self) -> int:  # in bytes
         # operators can be applied to all workers.
         return min(output)
 
-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_specs(self) -> List[KVCacheSpec]:
         output = self.collective_rpc("get_kv_cache_spec")
-        for x in output:
-            assert x == output[0]
-        return output[0]
+        return output
 
     def execute_model(
         self,

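For context, a hypothetical mini-executor (only the method names mirror the diff; nothing else is the real Executor API) showing why get_kv_cache_specs now returns the whole per-worker list: collective_rpc yields one result per worker, and under pipeline parallelism those specs legitimately differ, so the old all-equal assertion no longer applies.

from typing import Any, List


class ToyWorker:
    def __init__(self, layer_names: List[str]):
        self.layer_names = layer_names

    def get_kv_cache_spec(self) -> dict:
        # Each pipeline rank reports only the layers it actually hosts.
        return {name: "full_attention" for name in self.layer_names}


class ToyExecutor:
    def __init__(self, workers: List[ToyWorker]):
        self.workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> List[Any]:
        # One return value per worker, in rank order.
        return [getattr(w, method)(*args) for w in self.workers]

    def get_kv_cache_specs(self) -> List[dict]:
        # Preserve the per-worker specs instead of asserting they are equal.
        return self.collective_rpc("get_kv_cache_spec")


executor = ToyExecutor([ToyWorker(["layers.0", "layers.1"]),
                        ToyWorker(["layers.2", "layers.3"])])
print(executor.get_kv_cache_specs())
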
vllm/v1/worker/gpu_model_runner.py

Lines changed: 15 additions & 1 deletion
@@ -12,7 +12,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
@@ -21,6 +21,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
@@ -773,6 +774,7 @@ def get_model(self) -> nn.Module:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         batch_changed = self._update_states(scheduler_output)
 
@@ -831,8 +833,11 @@ def execute_model(
             positions=positions,
             kv_caches=self.kv_caches,
             attn_metadata=None,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1007,12 +1012,19 @@ def _dummy_run(
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device)
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
         return hidden_states
@@ -1142,6 +1154,8 @@ def profile_run(self) -> None:
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens,
                                         dummy_kv_caches)
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[logit_indices]
         logits = self.model.compute_logits(hidden_states, None)
         # TODO(woosuk): Consider the memory usage of the sampler.

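A small sketch of the _dummy_run change, assuming (as the diff suggests) that models expose make_empty_intermediate_tensors(batch_size, dtype, device); the helper and sizes below are placeholders. During profiling and warm-up there is no upstream rank to send activations, so non-first ranks feed zero-filled intermediate tensors of the expected shape into the forward pass.

import torch


def make_empty_intermediate_tensors(batch_size: int, hidden_size: int,
                                    dtype: torch.dtype,
                                    device: torch.device) -> dict:
    # Placeholder for model.make_empty_intermediate_tensors(); the real model
    # decides which tensors (hidden_states, residual, ...) the stage expects.
    return {
        "hidden_states": torch.zeros(batch_size, hidden_size,
                                     dtype=dtype, device=device),
        "residual": torch.zeros(batch_size, hidden_size,
                                dtype=dtype, device=device),
    }


is_first_rank = False            # pretend this is pipeline rank > 0
num_tokens, hidden_size = 256, 4096

intermediate_tensors = None
if not is_first_rank:
    intermediate_tensors = make_empty_intermediate_tensors(
        batch_size=num_tokens, hidden_size=hidden_size,
        dtype=torch.float16, device=torch.device("cpu"))

print(None if intermediate_tensors is None else
      intermediate_tensors["hidden_states"].shape)
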
vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 2 deletions
@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.distributed
@@ -195,8 +195,9 @@ def determine_available_memory(self) -> int:
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()
 
-    def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        kv_cache_config = kv_cache_configs[self.rank]
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")

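Finally, a toy illustration (ToyWorker is hypothetical) of the per-rank selection added to initialize_cache: the engine broadcasts the full list of KV cache configs and each worker indexes it by its own rank.

from typing import List


class ToyWorker:
    def __init__(self, rank: int):
        self.rank = rank
        self.kv_cache_config = None

    def initialize_cache(self, kv_cache_configs: List[dict]) -> None:
        # One config per worker in the broadcast list; take ours by rank.
        self.kv_cache_config = kv_cache_configs[self.rank]


configs = [{"num_blocks": 1024, "layers": ["layers.0", "layers.1"]},
           {"num_blocks": 1024, "layers": ["layers.2", "layers.3"]}]
workers = [ToyWorker(0), ToyWorker(1)]
for worker in workers:
    worker.initialize_cache(configs)
print([w.kv_cache_config["layers"] for w in workers])
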