Skip to content

Commit a12b247

Browse files
committed
rebase and move kv cache group validation
Signed-off-by: qizixi <[email protected]>
1 parent 34d006f commit a12b247

File tree

4 files changed

+54
-9
lines changed

4 files changed

+54
-9
lines changed

profile_client.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
export LLAMA_DIR=meta-llama/Llama-3.1-8B-Instruct
2+
3+
python benchmarks/benchmark_serving.py \
4+
--backend vllm \
5+
--model $LLAMA_DIR \
6+
--dataset-name random \
7+
--random-input-len 2000 \
8+
--random-output-len 150 \
9+
--max-concurrency 2 \
10+
--num-prompts 2 \
11+
--profile \
12+
2>&1 | tee ./base_client_$(date +%Y%m%d_%H%M%S).log

server.sh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
2+
# Configuration of environment variables
3+
export VLLM_USE_MODELSCOPE=False
4+
export VLLM_TORCH_PROFILER_DIR=~/vllm_profile
5+
export LLAMA_DIR=meta-llama/Llama-3.1-8B-Instruct
6+
export VLLM_USE_V1=1
7+
# Eagle config
8+
spec_dec_config='{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 2048}'
9+
# Command to run the vllm server
10+
vllm serve $LLAMA_DIR --disable-log-requests \
11+
-tp 1 \
12+
--max-num-seqs 128 \
13+
--max_num_batched_tokens=8000 \
14+
--speculative-config="$spec_dec_config" \
15+
--num-lookahead-slots=3 \
16+
--max-model-len=8192 \
17+
--enable-prefix-caching \
18+
--trust-remote-code \
19+
2>&1 | tee /data/users/$USER/logs/server/vllm_17b16e_vllm_serving$(date +%Y%m%d_%H%M%S).log

vllm/v1/spec_decode/eagle.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
1515
from vllm.triton_utils import tl, triton
1616
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
17+
from vllm.v1.kv_cache_interface import KVCacheConfig
1718
from vllm.v1.sample.metadata import SamplingMetadata
1819

1920
logger = init_logger(__name__)
@@ -352,6 +353,24 @@ def dummy_run(
352353
hidden_states=self.hidden_states[:num_tokens],
353354
)
354355

356+
def validate_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
357+
"""
358+
Validate that all eagle layers belong to the same KVCacheGroup.
359+
Need this assumption to ensure all eagle layers can use the
360+
same AttentionMetadata.
361+
May extend to multiple AttentionMetadata in the future.
362+
"""
363+
kv_cache_groups: dict[str, int] = {}
364+
for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
365+
for layer_name in kv_cache_group.layer_names:
366+
kv_cache_groups[layer_name] = id
367+
assert len(
368+
set([
369+
kv_cache_groups[layer_name]
370+
for layer_name in self.attn_layer_names
371+
])
372+
) == 1, "All eagle layers should belong to the same kv cache group"
373+
355374

356375
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
357376
# to sample the draft tokens. We will use this after we find a way to manage

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,12 +1890,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
18901890
self.initialize_attn_backend(kv_cache_config)
18911891

18921892
kv_caches: dict[str, torch.Tensor] = {}
1893-
kv_cache_group_ids: dict[str, int] = {}
18941893

1895-
for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
1894+
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
18961895
kv_cache_spec = kv_cache_group.kv_cache_spec
18971896
for layer_name in kv_cache_group.layer_names:
1898-
kv_cache_group_ids[layer_name] = id
18991897
tensor_config = kv_cache_config.tensors[layer_name]
19001898
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
19011899
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
@@ -1922,12 +1920,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
19221920

19231921
if self.speculative_config and self.speculative_config.use_eagle():
19241922
assert isinstance(self.drafter, EagleProposer)
1925-
assert len(
1926-
set([
1927-
kv_cache_group_ids[layer_name]
1928-
for layer_name in self.drafter.attn_layer_names
1929-
])) == 1, "For multi-layer eagle draft model, "
1930-
"all layers should belong to the same kv cache group"
1923+
# validate all draft model layers belong to the same kv cache
1924+
# group
1925+
self.drafter.validate_kv_cache_group(kv_cache_config)
19311926

19321927
bind_kv_cache(
19331928
kv_caches,

0 commit comments

Comments
 (0)