Skip to content

Commit 6bb2416

Browse files
WoosukKwon
authored and
jimpang
committed
[V1][PP] Cache Intermediate Tensors (vllm-project#13353)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 4588387 commit 6bb2416

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import gc
44
import time
5-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
5+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
66

77
import numpy as np
88
import torch
@@ -149,6 +149,7 @@ def __init__(
149149
self.positions = torch.zeros(self.max_num_tokens,
150150
dtype=torch.int64,
151151
device=self.device)
152+
# self.intermediate_tensors # Set after load_model
152153

153154
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
154155
if self.uses_mrope:
@@ -869,7 +870,7 @@ def execute_model(
869870
self,
870871
scheduler_output: "SchedulerOutput",
871872
intermediate_tensors: Optional[IntermediateTensors] = None,
872-
) -> ModelRunnerOutput:
873+
) -> Union[ModelRunnerOutput, torch.Tensor]:
873874
batch_changed = self._update_states(scheduler_output)
874875

875876
if self.is_multimodal_model:
@@ -919,6 +920,14 @@ def execute_model(
919920
else:
920921
positions = self.positions[:num_input_tokens]
921922

923+
if get_pp_group().is_first_rank:
924+
intermediate_tensors = None
925+
else:
926+
intermediate_tensors = IntermediateTensors({
927+
k: v[:num_input_tokens]
928+
for k, v in self.intermediate_tensors.items()
929+
})
930+
922931
# Run the decoder.
923932
# Use persistent buffers for CUDA graphs.
924933
with set_forward_context(attn_metadata, self.vllm_config):
@@ -931,7 +940,9 @@ def execute_model(
931940
inputs_embeds=inputs_embeds,
932941
)
933942
if not get_pp_group().is_last_rank:
943+
# For mid-pipeline stages, return the hidden states.
934944
return hidden_states
945+
935946
hidden_states = hidden_states[:num_scheduled_tokens]
936947
sample_hidden_states = hidden_states[logits_indices]
937948
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1118,12 +1129,21 @@ def _dummy_run(
11181129
positions = self.mrope_positions[:, :num_tokens]
11191130
else:
11201131
positions = self.positions[:num_tokens]
1121-
intermediate_tensors = None
1122-
if not get_pp_group().is_first_rank:
1123-
intermediate_tensors = self.model.make_empty_intermediate_tensors(
1124-
batch_size=num_tokens,
1125-
dtype=self.model_config.dtype,
1126-
device=self.device)
1132+
1133+
if get_pp_group().is_first_rank:
1134+
intermediate_tensors = None
1135+
else:
1136+
if not hasattr(self, "intermediate_tensors"):
1137+
self.intermediate_tensors = (
1138+
self.model.make_empty_intermediate_tensors(
1139+
batch_size=self.max_num_tokens,
1140+
dtype=self.model_config.dtype,
1141+
device=self.device))
1142+
intermediate_tensors = IntermediateTensors({
1143+
k: v[:num_tokens]
1144+
for k, v in self.intermediate_tensors.items()
1145+
})
1146+
11271147
with set_forward_context(None, self.vllm_config):
11281148
hidden_states = model(
11291149
input_ids=input_ids,

0 commit comments

Comments (0)