```diff
@@ -2,7 +2,7 @@
 """
 Simple KV Cache Connector for Distributed Machine Learning Inference
 
-The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
 producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
 MooncakePipe.
 
```
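For context, the producer/consumer pair this docstring describes is wired together through vLLM's KV-transfer configuration. A minimal offline sketch, modeled on vLLM's disaggregated-prefill example; the model name, ranks, and exact `KVTransferConfig` fields are assumptions that may vary across vLLM versions, and each instance would normally run in its own process:

```python
from vllm import LLM
from vllm.config import KVTransferConfig

# Prefill instance: acts as the KV cache producer, pushing caches
# through the connector (PyNcclPipe under the hood here).
prefill_llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector": "PyNcclConnector", "kv_role": "kv_producer", '
        '"kv_rank": 0, "kv_parallel_size": 2}'),
)

# Decode instance: acts as the KV cache consumer and skips prefill
# computation for tokens whose KV cache arrives over the pipe.
decode_llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector": "PyNcclConnector", "kv_role": "kv_consumer", '
        '"kv_rank": 1, "kv_parallel_size": 2}'),
)
```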
```diff
@@ -159,21 +159,32 @@ def send_kv_caches_and_hidden_states(
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
 
         model_config = model_executable.model.config
         num_heads = int(model_config.num_key_value_heads / self.tp_size)
         hidden_size = model_config.hidden_size
         num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        head_size = getattr(model_config, "head_dim",
+                            int(hidden_size // num_attention_heads))
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
         # FIXME(Kuntai): This assume that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You have some decode requests while using "
+                               "SimpleConnector. Their KVCache won't be sent.")
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
 
             keys, values = [], []
```
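The `head_size` change matters for models whose per-head dimension is decoupled from `hidden_size / num_attention_heads`. Gemma-style configs, for instance, expose an explicit `head_dim` that the old formula gets wrong. A minimal sketch of the two cases; the config classes are stand-ins, with numbers taken from typical Gemma and Llama configs:

```python
class GemmaLikeConfig:
    hidden_size = 3072
    num_attention_heads = 16
    head_dim = 256  # explicit, and != 3072 // 16 == 192

class LlamaLikeConfig:
    hidden_size = 4096
    num_attention_heads = 32  # no explicit head_dim attribute

for cfg in (GemmaLikeConfig(), LlamaLikeConfig()):
    # Same pattern as the patched line: prefer head_dim when present,
    # otherwise fall back to the hidden_size ratio.
    head_size = getattr(cfg, "head_dim",
                        cfg.hidden_size // cfg.num_attention_heads)
    print(type(cfg).__name__, head_size)
# GemmaLikeConfig 256  (old formula would have sliced 192-wide heads)
# LlamaLikeConfig 128
```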
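The new early-exit guard relies on the batch layout built by `_prepare_model_input_tensors`: the flattened token tensor stores every prefill token first, followed by one token per decode request, so once `start_pos` reaches `num_prefill_tokens` nothing that remains has KV produced by this prefill pass. A toy walk-through of the loop's logic, with made-up tensor contents:

```python
import torch

# Two prefill sequences (lengths 3 and 2) followed by two decode
# sequences (one new token each), flattened as model_runner does.
input_tokens_tensor = torch.tensor([11, 12, 13, 21, 22, 31, 41])
seq_lens = [3, 2, 1, 1]
num_prefill_tokens = 5  # attn_metadata.num_prefill_tokens

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    end_pos = start_pos + slen
    if start_pos >= num_prefill_tokens:
        # Everything from here on is a decode token; its KV cache was
        # not produced by this prefill pass, so stop sending.
        print("stop at request", idx)
        break
    print("send", input_tokens_tensor[start_pos:end_pos].tolist())
# send [11, 12, 13]
# send [21, 22]
# stop at request 2
```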
```diff
@@ -236,7 +247,7 @@ def recv_kv_caches_and_hidden_states(
                 # - input_tokens[num_prefill_tokens:] contains decode tokens.
                 logger.warning("You should set --enable_chunked_prefill=False "
                                "and --max_num_batched_tokens "
-                               "should be equal to max_seq_len_to_capture")
+                               "should be equal to --max_seq_len_to_capture")
                 bypass_model_exec = False
                 assert start_pos == num_prefill_tokens
                 break
```
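For reference, the flags named in the corrected warning map to engine arguments in offline use. A hedged sketch of the configuration the warning asks for; the model name and the 8192 value are placeholders:

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
    enable_chunked_prefill=False,
    max_num_batched_tokens=8192,
    # The warning wants this to match max_num_batched_tokens.
    max_seq_len_to_capture=8192,
)
```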