2
2
3
3
import gc
4
4
import time
5
- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , cast
5
+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Union , cast
6
6
7
7
import numpy as np
8
8
import torch
@@ -149,6 +149,7 @@ def __init__(
149
149
self .positions = torch .zeros (self .max_num_tokens ,
150
150
dtype = torch .int64 ,
151
151
device = self .device )
152
+ # NOTE: self.intermediate_tensors is not allocated here; it is set after load_model (created lazily in _dummy_run on non-first PP ranks).
152
153
153
154
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
154
155
if self .uses_mrope :
@@ -869,7 +870,7 @@ def execute_model(
869
870
self ,
870
871
scheduler_output : "SchedulerOutput" ,
871
872
intermediate_tensors : Optional [IntermediateTensors ] = None ,
872
- ) -> ModelRunnerOutput :
873
+ ) -> Union [ ModelRunnerOutput , torch . Tensor ] :
873
874
batch_changed = self ._update_states (scheduler_output )
874
875
875
876
if self .is_multimodal_model :
@@ -919,6 +920,14 @@ def execute_model(
919
920
else :
920
921
positions = self .positions [:num_input_tokens ]
921
922
923
+ if get_pp_group ().is_first_rank :
924
+ intermediate_tensors = None
925
+ else :
926
+ intermediate_tensors = IntermediateTensors ({
927
+ k : v [:num_input_tokens ]
928
+ for k , v in self .intermediate_tensors .items ()
929
+ })
930
+
922
931
# Run the decoder.
923
932
# Use persistent buffers for CUDA graphs.
924
933
with set_forward_context (attn_metadata , self .vllm_config ):
@@ -931,7 +940,9 @@ def execute_model(
931
940
inputs_embeds = inputs_embeds ,
932
941
)
933
942
if not get_pp_group ().is_last_rank :
943
+ # For every pipeline stage except the last, return the hidden states.
934
944
return hidden_states
945
+
935
946
hidden_states = hidden_states [:num_scheduled_tokens ]
936
947
sample_hidden_states = hidden_states [logits_indices ]
937
948
logits = self .model .compute_logits (sample_hidden_states , None )
@@ -1118,12 +1129,21 @@ def _dummy_run(
1118
1129
positions = self .mrope_positions [:, :num_tokens ]
1119
1130
else :
1120
1131
positions = self .positions [:num_tokens ]
1121
- intermediate_tensors = None
1122
- if not get_pp_group ().is_first_rank :
1123
- intermediate_tensors = self .model .make_empty_intermediate_tensors (
1124
- batch_size = num_tokens ,
1125
- dtype = self .model_config .dtype ,
1126
- device = self .device )
1132
+
1133
+ if get_pp_group ().is_first_rank :
1134
+ intermediate_tensors = None
1135
+ else :
1136
+ if not hasattr (self , "intermediate_tensors" ):
1137
+ self .intermediate_tensors = (
1138
+ self .model .make_empty_intermediate_tensors (
1139
+ batch_size = self .max_num_tokens ,
1140
+ dtype = self .model_config .dtype ,
1141
+ device = self .device ))
1142
+ intermediate_tensors = IntermediateTensors ({
1143
+ k : v [:num_tokens ]
1144
+ for k , v in self .intermediate_tensors .items ()
1145
+ })
1146
+
1127
1147
with set_forward_context (None , self .vllm_config ):
1128
1148
hidden_states = model (
1129
1149
input_ids = input_ids ,
0 commit comments