Commit 7be1c7c

rebase and move kv cache group validation
Signed-off-by: qizixi <[email protected]>
1 parent 34d006f commit 7be1c7c

2 files changed: +17 -9 lines

vllm/v1/spec_decode/eagle.py

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,7 @@
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata

 logger = init_logger(__name__)

@@ -352,6 +353,18 @@ def dummy_run(
             hidden_states=self.hidden_states[:num_tokens],
         )

+    def validate_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])) == 1, "For multi-layer eagle draft model, "
+        "all layers should belong to the same kv cache group"
+

 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage

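For reference, below is a minimal standalone sketch of the check that the new validate_kv_cache_group method performs. The stand-in dataclasses are hypothetical, simplified placeholders rather than the real KVCacheConfig and group types from vllm.v1.kv_cache_interface, and the sketch folds both string fragments into the assertion message; in the diff above, the second literal sits on its own line as a separate statement, so only the first half would appear in the error text.

from dataclasses import dataclass


@dataclass
class KVCacheGroupStandIn:
    # Simplified stand-in for a KV cache group spec; only layer_names is
    # needed for the grouping check.
    layer_names: list[str]


@dataclass
class KVCacheConfigStandIn:
    # Simplified stand-in for vllm.v1.kv_cache_interface.KVCacheConfig.
    kv_cache_groups: list[KVCacheGroupStandIn]


def validate_kv_cache_group(attn_layer_names: list[str],
                            kv_cache_config: KVCacheConfigStandIn) -> None:
    # Map every layer name to the index of the KV cache group that owns it.
    layer_to_group: dict[str, int] = {}
    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        for layer_name in group.layer_names:
            layer_to_group[layer_name] = group_id
    # All of the draft model's attention layers must resolve to one group.
    group_ids = {layer_to_group[name] for name in attn_layer_names}
    assert len(group_ids) == 1, (
        "For multi-layer eagle draft model, "
        "all layers should belong to the same kv cache group")


# Passes: both (hypothetical) draft layers live in group 0.
config = KVCacheConfigStandIn(kv_cache_groups=[
    KVCacheGroupStandIn(layer_names=["draft.layers.0", "draft.layers.1"]),
    KVCacheGroupStandIn(layer_names=["model.layers.0"]),
])
validate_kv_cache_group(["draft.layers.0", "draft.layers.1"], config)

# Would raise AssertionError: these two layers span groups 0 and 1.
# validate_kv_cache_group(["draft.layers.0", "model.layers.0"], config)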
vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 9 deletions
@@ -1890,12 +1890,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.initialize_attn_backend(kv_cache_config)

         kv_caches: dict[str, torch.Tensor] = {}
-        kv_cache_group_ids: dict[str, int] = {}

-        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for layer_name in kv_cache_group.layer_names:
-                kv_cache_group_ids[layer_name] = id
                 tensor_config = kv_cache_config.tensors[layer_name]
                 assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes

@@ -1922,12 +1920,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            assert len(
-                set([
-                    kv_cache_group_ids[layer_name]
-                    for layer_name in self.drafter.attn_layer_names
-                ])) == 1, "For multi-layer eagle draft model, "
-            "all layers should belong to the same kv cache group"
+            # validate all draft model layers belong to the same kv cache
+            # group
+            self.drafter.validate_kv_cache_group(kv_cache_config)

         bind_kv_cache(
             kv_caches,

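The runner-side change reduces to delegation: instead of building its own layer-to-group map, initialize_kv_cache now calls into the drafter. A rough, hypothetical sketch of that call-site shape, with placeholder types standing in for the real GPUModelRunner and its config objects:

from typing import Any, Optional


class ModelRunnerSketch:
    # Placeholder for GPUModelRunner; only the delegation path touched by
    # this commit is sketched here.

    def __init__(self, drafter: Any, speculative_config: Optional[Any]) -> None:
        self.drafter = drafter
        self.speculative_config = speculative_config

    def initialize_kv_cache(self, kv_cache_config: Any) -> None:
        # ... per-layer KV cache tensor allocation, unchanged by this commit ...
        if self.speculative_config and self.speculative_config.use_eagle():
            # The drafter (an EagleProposer) now owns the check that all of
            # its attention layers share a single KV cache group.
            self.drafter.validate_kv_cache_group(kv_cache_config)
        # ... bind_kv_cache(...) as before ...

Keeping the assertion on the proposer puts the invariant next to the attn_layer_names it constrains, so the runner no longer needs its own kv_cache_group_ids bookkeeping.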