@@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states(

         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

         hidden_or_intermediate_states_for_one_req = []
@@ -225,9 +226,21 @@ def recv_kv_caches_and_hidden_states(
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
-
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and --max_num_batched_tokens "
+                               "should be equal to max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
             num_tokens = slen

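To see why this guard is enough, here is a minimal, self-contained sketch (illustrative values only, not part of this change) of the flat batch layout the new comments describe: prefill tokens are laid out first in input_tokens, decode requests follow, and the cumulative start_pos first reaches num_prefill_tokens exactly at the first decode request, at which point the loop bails out and forces a normal forward pass.

# Illustrative sketch with made-up values; mirrors the layout described in
# vllm/worker/model_runner.py::_prepare_model_input_tensors.
seq_lens = [5, 3, 9, 12]     # two prefill requests followed by two decode requests
num_prefill_tokens = 5 + 3   # only the prefill tokens are laid out contiguously first

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    if start_pos >= num_prefill_tokens:
        # idx == 2: start_pos == 8 == num_prefill_tokens. Every remaining entry
        # is a decode request, so slicing input_tokens by seq_lens would read
        # past the prefill region; the connector stops receiving KVs here.
        break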
@@ -288,7 +301,7 @@ def recv_kv_caches_and_hidden_states(
             # Here we will fall back to normal model forwarding
             # But optionally you can adjust model_input so that you only do
             # prefilling on those tokens that are missing KV caches.
-            logger.debug(
+            logger.warning(
                 "[rank%d]: Failed to receive all KVs and hidden "
                 "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None
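The promoted warning documents the fallback contract this path relies on. A rough, self-contained sketch of that contract follows; the function and variable names here are hypothetical and simplified, not vLLM's actual worker code: if any request's KV cache could not be received, bypass_model_exec stays False and the caller redoes the forward pass.

# Hypothetical, simplified sketch (illustrative names, not vLLM's API) of the
# fallback the warning above refers to.
from typing import List, Optional

def maybe_bypass_forward(received: List[bool]) -> Optional[str]:
    bypass_model_exec = all(received)
    if not bypass_model_exec:
        # Some KVs/hidden states are missing, so return None to signal the
        # caller to redo normal model forwarding for the whole batch.
        return None
    return "reuse received hidden states"

print(maybe_bypass_forward([True, True]))   # -> "reuse received hidden states"
print(maybe_bypass_forward([True, False]))  # -> None (redo model forwarding)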