@@ -151,7 +151,8 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
-        # self.intermediate_tensors # Set after load_model
+        # None in the first PP rank. The rest are set after load_model.
+        self.intermediate_tensors: Optional[IntermediateTensors] = None
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -922,6 +923,11 @@ def execute_model(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
             intermediate_tensors = IntermediateTensors({
                 k: v[:num_input_tokens]
                 for k, v in self.intermediate_tensors.items()
@@ -1120,7 +1126,7 @@ def _dummy_run(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            if not hasattr(self, "intermediate_tensors"):
+            if self.intermediate_tensors is None:
                 self.intermediate_tensors = (
                     self.model.make_empty_intermediate_tensors(
                         batch_size=self.max_num_tokens,
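The pattern this diff introduces — keep long-lived buffers for the pipeline-parallel intermediate tensors, copy each step's incoming tensors into them asynchronously, and hand slices of the persistent buffers to the model — can be illustrated with a small standalone sketch. The class and method names below (`PersistentIntermediateBuffers`, `stage`) are hypothetical and not part of vLLM; only the lazy allocation, `copy_(..., non_blocking=True)`, and `[:num_input_tokens]` slicing mirror the change above.

```python
# Minimal sketch, assuming plain dicts of tensors in place of IntermediateTensors.
from typing import Dict, Optional

import torch


class PersistentIntermediateBuffers:
    """Pre-allocates per-key tensors once and reuses them every step."""

    def __init__(self, max_num_tokens: int, hidden_size: int,
                 device: torch.device) -> None:
        self._shape = (max_num_tokens, hidden_size)
        self._device = device
        # Allocated lazily, mirroring "set after load_model" in the diff.
        self.buffers: Optional[Dict[str, torch.Tensor]] = None

    def _ensure_allocated(self) -> Dict[str, torch.Tensor]:
        if self.buffers is None:
            self.buffers = {
                "hidden_states": torch.zeros(*self._shape, device=self._device),
                "residual": torch.zeros(*self._shape, device=self._device),
            }
        return self.buffers

    def stage(self, incoming: Dict[str, torch.Tensor],
              num_input_tokens: int) -> Dict[str, torch.Tensor]:
        """Copy incoming tensors into the persistent buffers and return
        views sized to the current batch, as the non-first PP ranks do."""
        buffers = self._ensure_allocated()
        for k, v in incoming.items():
            # Async copy into the long-lived buffer; the returned slices
            # alias the buffers, not `incoming`.
            buffers[k][:num_input_tokens].copy_(v[:num_input_tokens],
                                                non_blocking=True)
        return {k: v[:num_input_tokens] for k, v in buffers.items()}
```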