[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache #14369

Merged 1 commit on Mar 8, 2025

17 changes: 14 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -2,7 +2,7 @@
"""
Simple KV Cache Connector for Distributed Machine Learning Inference

The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.

@@ -159,21 +159,32 @@ def send_kv_caches_and_hidden_states(
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer

model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = getattr(model_config, "head_dim",
                    int(hidden_size // num_attention_heads))

# seq_lens contains the new KV caches that are added to vLLM,
# so we will send them to the decode instance.
# FIXME(Kuntai): This assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen

if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break

current_tokens = input_tokens_tensor[start_pos:end_pos]

keys, values = [], []
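
Note on the head_size change above: some model configs set head_dim explicitly, and it need not equal hidden_size // num_attention_heads, so reshaping the transferred KV cache with the derived value can be wrong for those models. A minimal standalone sketch with made-up values (FakeConfig and its numbers are illustrative, not from vLLM):

```python
# Illustrative only: made-up config where the explicit head_dim differs
# from the naive hidden_size // num_attention_heads derivation.
class FakeConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = 256  # explicit value, not the derived 4096 // 32 == 128

config = FakeConfig()

derived = config.hidden_size // config.num_attention_heads  # 128
actual = getattr(config, "head_dim", derived)               # 256

# Reshaping per-layer KV tensors with the derived size would scramble
# the cache for such a model; the getattr lookup prefers the explicit
# value and falls back to the derivation otherwise.
print(derived, actual)  # -> 128 256
```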
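
And a standalone sketch of the new num_prefill_tokens guard, using made-up lengths: the flattened batch lays out prefill tokens first and decode tokens after them, so the loop can stop at the first request whose start offset falls in the decode region.

```python
# Made-up numbers: two prefill requests followed by one decode request
# batched together. start_pos accumulates over seq_lens exactly as in
# the loop above.
seq_lens = [5, 3, 1]     # 5- and 3-token prefills, then one decode step
num_prefill_tokens = 8   # input_tokens[:8] are the prefill tokens

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    if start_pos >= num_prefill_tokens:
        print(f"request {idx}: decode, KV cache not sent")
        break
    print(f"request {idx}: prefill tokens [{start_pos}, {start_pos + slen})")
```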
@@ -236,7 +247,7 @@ def recv_kv_caches_and_hidden_states(
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture")
"should be equal to --max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
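
For the corrected warning in this hunk, a hedged example of a decode-side engine configured the way the message asks, via the offline Python entrypoint (the model name and sizes are placeholders, and the kv-transfer settings a real disaggregated setup also needs are omitted):

```python
# Placeholders throughout; this only shows the three settings the
# warning refers to, not a complete disaggregated-prefill launch.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_chunked_prefill=False,
    max_num_batched_tokens=8192,  # keep equal to max_seq_len_to_capture
    max_seq_len_to_capture=8192,
)
```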