
[V1][Spec Decode] KV cache slots for eagle heads #16370

Merged: 5 commits merged on Apr 13, 2025
86 changes: 74 additions & 12 deletions tests/v1/core/test_kv_cache_utils.py
@@ -7,6 +7,7 @@
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
@@ -48,6 +49,18 @@ def make_request(request_id,
)


def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)


def test_none_hash():
assert NONE_HASH is not None
assert isinstance(NONE_HASH, int)
@@ -327,18 +340,6 @@ def stats(requests, queries, hits):


def test_unify_kv_cache_configs():

def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)

same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
8 * GiB_bytes)
assert estimated_max_len == want_estimated_max_len


def test_allocate_with_lookahead():
"""Verify that lookahead tokens correctly affect block allocation"""
block_size = 4
config = KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"],
new_kv_cache_spec(block_size=block_size)),
],
)

request = make_request(
request_id=0,
prompt_token_ids=[],
mm_positions=None,
mm_hashes=None,
)

# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=0)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks) == 2 # ceil(5/4)=2 blocks

    # Test case 2: With preallocated blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
    # required_blocks = ceil((3 + 2) / 4) = 2
# total_blocks = 1 + 2 = 3
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 3

    # Test case 3: With preallocated blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
Contributor

It seems a little counterintuitive: when num_lookahead_tokens increases (which means we may need more slots to allocate), the number of allocated blocks decreases.

In test case 2, the num_lookahead_tokens does not use the slots reserved for preallocate_tokens. Is there a particular reason for this design?

Collaborator (Author) @LiuXiaoxuanPKU, Apr 13, 2025

Yeah, agreed. I feel it's a corner case: num_lookahead_tokens does not use the slots for preallocate_tokens because we calculate at the block level, and 3 // 4 = 0 (the lookahead tokens borrow 0 blocks from the preallocate tokens).

when num_lookahead_tokens=1, 2, 3, len(blocks) = 3
when num_lookahead_tokens=4, len(blocks) = 2
when num_lookahead_tokens=5, 6, 7, ..., len(blocks) = ceil((num_lookahead_tokens + 4) / 4) = 3 or bigger

number of required blocks = number of blocks required by lookahead slots + number of blocks required by the tokens to compute + number of preallocate blocks
number of preallocate blocks = max(0, a constant - number of blocks required by lookahead slots)
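
As a sanity check, here is a minimal standalone sketch of the arithmetic above (a hypothetical helper, not vLLM code), assuming block_size=4, num_tokens=3, one preallocated block, and an empty request with no computed tokens, matching test cases 2 and 3:

from math import ceil

def allocated_blocks(num_lookahead_tokens, num_tokens=3, block_size=4,
                     num_preallocate_blocks=1):
    # Blocks needed for the new tokens plus the speculative lookahead slots.
    num_required = ceil((num_tokens + num_lookahead_tokens) / block_size)
    # Preallocation shrinks by the whole blocks the lookahead tokens cover
    # (block-level accounting, hence the corner case discussed above).
    num_preallocate = max(
        0, num_preallocate_blocks - num_lookahead_tokens // block_size)
    return num_required + num_preallocate

print(allocated_blocks(2))  # 3 blocks, as in test case 2
print(allocated_blocks(4))  # 2 blocks, as in test case 3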

Contributor

I think with num_lookahead_tokens=1, 2, 3, len(blocks) should still be 2?

IIUC, num_lookahead_tokens can use up the space for preallocate_tokens. So basically, before actually preallocating, we just need to check whether the preallocated blocks are already taken up by the lookahead tokens.

Maybe the pseudocode could be something like:

num_required_blocks_before_lookahead = cdiv(
    num_computed_tokens + num_tokens, self.block_size)
num_required_blocks = cdiv(
    num_computed_tokens + num_tokens + num_lookahead_tokens,
    self.block_size)
num_required_blocks_used_by_lookahead = (
    num_required_blocks - num_required_blocks_before_lookahead)

# If the preallocated blocks have already been used up by the lookahead
# tokens, we don't need to allocate them further.
num_preallocate_blocks = max(
    0, self.num_preallocate_blocks - num_required_blocks_used_by_lookahead)
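
For comparison, a runnable sketch of this proposed accounting (a hypothetical helper, not vLLM code, under the same assumptions as the sketch above: block_size=4, num_tokens=3, no computed tokens, one preallocated block), which does yield 2 blocks for num_lookahead_tokens=1, 2, 3:

from math import ceil

def allocated_blocks_proposed(num_lookahead_tokens, num_tokens=3,
                              num_computed_tokens=0, block_size=4,
                              num_preallocate_blocks=1):
    # Blocks needed before and after adding the lookahead slots.
    before = ceil((num_computed_tokens + num_tokens) / block_size)
    after = ceil(
        (num_computed_tokens + num_tokens + num_lookahead_tokens) / block_size)
    used_by_lookahead = after - before
    # Only preallocate blocks not already consumed by the lookahead tokens.
    preallocate = max(0, num_preallocate_blocks - used_by_lookahead)
    return after + preallocate

print([allocated_blocks_proposed(n) for n in (1, 2, 3)])  # [2, 2, 2]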


16 changes: 12 additions & 4 deletions vllm/v1/core/kv_cache_manager.py
@@ -164,7 +164,8 @@ def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append.
@@ -174,6 +175,9 @@
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
-----------------------------------------------------------------------
@@ -211,8 +215,9 @@
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))

@@ -246,8 +251,11 @@
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
13 changes: 11 additions & 2 deletions vllm/v1/core/sched/scheduler.py
@@ -7,7 +7,8 @@
from collections.abc import Iterable
from typing import Optional, Union

from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -39,6 +40,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
speculative_config: SpeculativeConfig = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
@@ -112,6 +114,11 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

self.num_lookahead_tokens = 0
if speculative_config and speculative_config.method == "eagle":
self.num_lookahead_tokens = \
speculative_config.num_speculative_tokens

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -188,7 +195,9 @@ def schedule(self) -> SchedulerOutput:

while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
1 change: 1 addition & 0 deletions vllm/v1/engine/core.py
@@ -98,6 +98,7 @@ def __init__(
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config,
speculative_config=vllm_config.speculative_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,