
Commit 96e0a1a

hasB4K and lulmer authored and committed
[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (vllm-project#13987)
Signed-off-by: Mathis Felardos <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 01c5d36 commit 96e0a1a
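
For context, a minimal sketch (hypothetical values, not code from this commit) of the batch layout the fix relies on: with inflight batching, `_prepare_model_input_tensors` packs prefill tokens first and decode tokens after them, so any request whose start offset is at or past `num_prefill_tokens` is in its decode phase and has no KV cache coming from the prefill node. Blocking on the NCCL-based SimpleBuffer for such a request is what caused the hangs.

```python
# Hypothetical batch: two prefill requests followed by two decode requests.
# Layout assumed by the fix (see the comment in the diff below):
#   input_tokens[:num_prefill_tokens]  -> prefill tokens
#   input_tokens[num_prefill_tokens:]  -> decode tokens
seq_lens = [5, 3, 1, 1]     # per-request token counts in this batch
num_prefill_tokens = 8      # 5 + 3; decode tokens start at offset 8

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    if start_pos >= num_prefill_tokens:
        # Decode-phase entry: the prefill node never pushed KVs for it,
        # so trying to receive them over NCCL would block forever.
        print(f"request {idx}: decode phase, stop receiving KVs here")
        break
    print(f"request {idx}: prefill tokens [{start_pos}, {start_pos + slen})")
```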

1 file changed: +15 -2 lines changed

vllm/distributed/kv_transfer/kv_connector/simple_connector.py

Lines changed: 15 additions & 2 deletions
@@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states(
 
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
 
         hidden_or_intermediate_states_for_one_req = []
@@ -225,9 +226,21 @@
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
-
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and --max_num_batched_tokens "
+                               "should be equal to max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
             num_tokens = slen
 
@@ -288,7 +301,7 @@
             # Here we will fall back to normal model forwarding
             # But optionally you can adjust model_input so that you only do
             # prefilling on those tokens that are missing KV caches.
-            logger.debug(
+            logger.warning(
                 "[rank%d]: Failed to receive all KVs and hidden "
                 "states, redo model forwarding.", torch.distributed.get_rank())
             hidden_or_intermediate_states = None
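
A note on the fallback path, as a hedged sketch (simplified control flow; `finish_recv` and `run_model` are illustrative names, not vLLM APIs): resetting `bypass_model_exec` to False makes the decode node redo the forward pass instead of waiting on a receive that would never complete. The warning text in the second hunk also points to the configuration that avoids mixed prefill/decode batches entirely: disable chunked prefill and set `max_num_batched_tokens` equal to `max_seq_len_to_capture`.

```python
import logging

logger = logging.getLogger(__name__)

def finish_recv(bypass_model_exec: bool, run_model):
    """Illustrative only: how the bypass flag is consumed after the
    receive loop. run_model stands in for the normal forward pass."""
    if not bypass_model_exec:
        # Mirrors the logger.warning promoted in the last hunk: some KVs
        # or hidden states were not received, so redo model forwarding.
        logger.warning("Failed to receive all KVs and hidden states, "
                       "redo model forwarding.")
        return run_model()
    # All KVs and hidden states arrived; skip the forward pass entirely.
    return None
```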
