We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0814f64 commit 04ba021Copy full SHA for 04ba021
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -35,6 +35,7 @@ def __init__(
35
):
36
37
self.config = config.kv_transfer_config
38
+ self.tp_size = config.parallel_config.tensor_parallel_size
39
40
if self.config.kv_connector == "PyNcclConnector":
41
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -161,7 +162,7 @@ def send_kv_caches_and_hidden_states(
161
162
end_layer = model_executable.model.end_layer
163
164
model_config = model_executable.model.config
- num_heads = model_config.num_key_value_heads
165
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
166
hidden_size = model_config.hidden_size
167
num_attention_heads = model_config.num_attention_heads
168
head_size = int(hidden_size / num_attention_heads)
0 commit comments