Skip to content

[V1][Spec Decode] KV cache slots for eagle heads #16370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append.

Expand All @@ -174,6 +175,9 @@ def allocate_slots(
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
            num_lookahead_tokens: The number of speculative tokens to allocate.
                This field is only used by EAGLE. We allocate the slots for
                the proposer heads.

Blocks layout:
-----------------------------------------------------------------------
Expand Down Expand Up @@ -211,8 +215,9 @@ def allocate_slots(
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))

Expand Down Expand Up @@ -246,8 +251,11 @@ def allocate_slots(
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
Expand Down
12 changes: 10 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from collections.abc import Iterable
from typing import Optional, Union

from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
        speculative_config: Optional[SpeculativeConfig] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
Expand Down Expand Up @@ -112,6 +114,10 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

self.num_spec_tokens = 0
if speculative_config and speculative_config.method == "eagle":
self.num_spec_tokens = speculative_config.num_speculative_tokens

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
Expand Down Expand Up @@ -188,7 +194,9 @@ def schedule(self) -> SchedulerOutput:

while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
request,
num_new_tokens,
num_lookahead_tokens=self.num_spec_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config,
speculative_config=vllm_config.speculative_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
Expand Down