Skip to content

Commit 980385f

Browse files
authored
[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369)
Signed-off-by: Mathis Felardos <[email protected]>
1 parent ca7a2d5 commit 980385f

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

vllm/distributed/kv_transfer/kv_connector/simple_connector.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Simple KV Cache Connector for Distributed Machine Learning Inference
44
5-
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
5+
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
66
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
77
MooncakePipe.
88
@@ -159,21 +159,32 @@ def send_kv_caches_and_hidden_states(
159159
input_tokens_tensor = model_input.input_tokens
160160
seq_lens = model_input.attn_metadata.seq_lens
161161
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
162+
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
162163
start_layer = model_executable.model.start_layer
163164
end_layer = model_executable.model.end_layer
164165

165166
model_config = model_executable.model.config
166167
num_heads = int(model_config.num_key_value_heads / self.tp_size)
167168
hidden_size = model_config.hidden_size
168169
num_attention_heads = model_config.num_attention_heads
169-
head_size = int(hidden_size / num_attention_heads)
170+
head_size = getattr(model_config, "head_dim",
171+
int(hidden_size // num_attention_heads))
170172

171173
# query_lens contains new KV caches that are added to vLLM.
172174
# so we will send them to decode instance
173175
# FIXME(Kuntai): This assume that all requests are prefill.
174176
for idx, slen in enumerate(seq_lens):
175177
start_pos = sum(seq_lens[:idx])
176178
end_pos = start_pos + slen
179+
180+
if start_pos >= num_prefill_tokens:
181+
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
182+
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
183+
# - input_tokens[num_prefill_tokens:] contains decode tokens.
184+
logger.warning("You have some decode requests while using "
185+
"SimpleConnector. Their KVCache won't be sent.")
186+
break
187+
177188
current_tokens = input_tokens_tensor[start_pos:end_pos]
178189

179190
keys, values = [], []
@@ -236,7 +247,7 @@ def recv_kv_caches_and_hidden_states(
236247
# - input_tokens[num_prefill_tokens:] contains decode tokens.
237248
logger.warning("You should set --enable_chunked_prefill=False "
238249
"and --max_num_batched_tokens "
239-
"should be equal to max_seq_len_to_capture")
250+
"should be equal to --max_seq_len_to_capture")
240251
bypass_model_exec = False
241252
assert start_pos == num_prefill_tokens
242253
break

0 commit comments

Comments (0)