[V1] Move KV block hashes from Request to KVCacheManager #12922

Status: Merged (6 commits, Feb 8, 2025)
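This PR moves the per-request cache of KV block hashes off the Request object and into a KVCacheManager mapping keyed by request ID, and has the scheduler explicitly discard that entry when a request finishes. A minimal sketch of the new shape, for orientation only (the BlockHashType stand-in below is simplified; its field names are taken from the tests in this diff):

from collections import defaultdict
from typing import DefaultDict, List, NamedTuple, Optional, Tuple

class BlockHashType(NamedTuple):
    # Simplified stand-in; the real type lives in vllm/v1/core/kv_cache_utils.py.
    hash_value: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Tuple[str, ...]] = None

class KVCacheManager:
    def __init__(self) -> None:
        # New in this PR: hashes are keyed by request ID here rather than
        # stored on the Request object.
        self.req_to_block_hashes: DefaultDict[
            str, List[BlockHashType]] = defaultdict(list)

Callers that previously read request.kv_block_hashes now read manager.req_to_block_hashes[request.request_id].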
21 changes: 11 additions & 10 deletions tests/v1/core/test_prefix_caching.py
@@ -51,7 +51,7 @@ def test_prefill():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert len(req0.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -76,7 +76,7 @@ def test_prefill():
     unique_token_ids = [3] * 5
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -107,7 +107,7 @@ def test_prefill():
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(req2.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
     # Completed block should have hashes with extra keys.
     assert not computed_blocks
     assert num_computed_tokens == 0
-    assert len(req0.kv_block_hashes) == 3
-    assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
-    assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
-    assert req0.kv_block_hashes[2].extra_keys == ("bbb", )
+    block_hashes = manager.req_to_block_hashes[req0.request_id]
+    assert len(block_hashes) == 3
+    assert block_hashes[0].extra_keys == ("aaa", )
+    assert block_hashes[1].extra_keys == ("aaa", "bbb")
+    assert block_hashes[2].extra_keys == ("bbb", )

     blocks = manager.allocate_slots(req0, 59, computed_blocks)
     assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
     assert new_blocks is not None and len(new_blocks) == 0

     # The just completed block should have hashes with extra keys.
-    assert len(req0.kv_block_hashes) == 4
-    assert req0.kv_block_hashes[3].extra_keys == ("ccc", )
+    assert len(block_hashes) == 4
+    assert block_hashes[3].extra_keys == ("ccc", )

     # Cache hit.
     unique_token_ids = [-1] * 7 + [200] * 5
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
     assert len(computed_blocks) == 3
     blocks = manager.allocate_slots(req1, 7, computed_blocks)
     assert [b.block_id for b in blocks] == [4]
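All of these assertions switch from req.kv_block_hashes to the manager-side mapping. A tiny hypothetical helper (not part of the PR) captures the new access pattern:

def block_hashes_of(manager, request):
    # Hypothetical test helper: reach a request's cached hashes the new way.
    return manager.req_to_block_hashes[request.request_id]

# e.g.: assert len(block_hashes_of(manager, req0)) == 3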
31 changes: 23 additions & 8 deletions vllm/v1/core/kv_cache_manager.py
@@ -72,6 +72,12 @@ def __init__(
         self.req_to_blocks: DefaultDict[str,
                                         List[KVCacheBlock]] = defaultdict(list)

+        # Mapping from request ID to kv block hashes.
+        # This is to avoid recomputing the block hashes for each call of
+        # `get_computed_blocks` or `allocate_slots`.
+        self.req_to_block_hashes: DefaultDict[
+            str, List[BlockHashType]] = defaultdict(list)
+
     @property
     def usage(self) -> float:
         return 1.0 - (self.free_block_queue.num_free_blocks /
@@ -97,11 +103,11 @@ def get_computed_blocks(
         computed_blocks = []

         # The block hashes for the request may already be computed
-        # if the request was preempted and resumed.
-        if not request.kv_block_hashes:
-            request.set_kv_block_hashes(
-                hash_request_tokens(self.block_size, request))
-        block_hashes = request.kv_block_hashes
+        # if the scheduler has tried to schedule the request before.
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        if not block_hashes:
+            block_hashes = hash_request_tokens(self.block_size, request)
+            self.req_to_block_hashes[request.request_id] = block_hashes

         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
@@ -437,7 +443,8 @@ def _cache_full_blocks(
             full_blocks: The list of blocks to update hash metadata.
             prev_block: The previous block in the chain.
         """
-        num_cached_block_hashes = len(request.kv_block_hashes)
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        num_cached_block_hashes = len(block_hashes)

         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
@@ -470,7 +477,7 @@ def _cache_full_blocks(
                 # this request (either the prompt tokens or the previously
                 # generated tokens with preemption). In this case we simply
                 # reuse the block hash.
-                block_hash = request.kv_block_hashes[blk_idx]
+                block_hash = block_hashes[blk_idx]
             else:
                 # Otherwise compute the block hash and cache it in the request
                 # in case it will be preempted in the future.
@@ -492,9 +499,17 @@ def _cache_full_blocks(
                 # Compute the hash of the current block.
                 block_hash = hash_block_tokens(prev_block_hash_value,
                                                block_tokens, extra_keys)
-                request.append_kv_block_hashes(block_hash)
+                block_hashes.append(block_hash)

             # Update and added the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
             prev_block_hash_value = block_hash.hash_value
+
+    def free_block_hashes(self, request: Request) -> None:
+        """Discard the block hashes for the request.
+
+        NOTE: Unlike `free`, this method should be called only when the request
+        is finished, not when it is preempted.
+        """
+        self.req_to_block_hashes.pop(request.request_id, None)
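With these pieces in place, the manager owns the whole lifecycle of a request's block hashes: get_computed_blocks computes and caches them lazily, _cache_full_blocks appends hashes for newly filled blocks, and free_block_hashes drops them at the end. A condensed, runnable sketch of that lifecycle, assuming a simplified chain-hash in place of the real hash_request_tokens:

from collections import defaultdict
from typing import DefaultDict, List

BLOCK_SIZE = 16

def hash_request_tokens(block_size: int, token_ids: List[int]) -> List[int]:
    # Stand-in for the real helper: chain-hash each *full* block of tokens.
    hashes: List[int] = []
    prev = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        prev = hash((prev, block))
        hashes.append(prev)
    return hashes

class ManagerSketch:
    def __init__(self) -> None:
        self.req_to_block_hashes: DefaultDict[str,
                                              List[int]] = defaultdict(list)

    def get_computed_blocks(self, req_id: str,
                            token_ids: List[int]) -> List[int]:
        # Lazily compute and cache; later calls (retries, resumption after
        # preemption) hit the cached list instead of re-hashing.
        block_hashes = self.req_to_block_hashes[req_id]
        if not block_hashes:
            block_hashes = hash_request_tokens(BLOCK_SIZE, token_ids)
            self.req_to_block_hashes[req_id] = block_hashes
        return block_hashes

    def free_block_hashes(self, req_id: str) -> None:
        # Only on finish, never on preemption.
        self.req_to_block_hashes.pop(req_id, None)

# Mirrors the tests above: 53 tokens with block size 16 give 3 full blocks.
assert len(ManagerSketch().get_computed_blocks("0", list(range(53)))) == 3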
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -579,6 +579,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.kv_cache_manager.free_block_hashes(request)
         self.encoder_cache_manager.free(request)
         self._cached_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
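The placement of the new call encodes the lifecycle split: free also runs on preemption, while free_block_hashes runs only here, on finish, so a preempted-then-resumed request still finds its cached hashes. An illustrative sketch of the two paths (simplified; these are not the scheduler's actual method names):

def on_preempt(scheduler, request):
    # Blocks go back to the pool, but the hash cache survives so that
    # get_computed_blocks() can skip re-hashing on resumption.
    scheduler.kv_cache_manager.free(request)

def on_finish(scheduler, request):
    scheduler.kv_cache_manager.free(request)
    scheduler.kv_cache_manager.free_block_hashes(request)  # now dropped too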
13 changes: 0 additions & 13 deletions vllm/v1/request.py
@@ -12,7 +12,6 @@
 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalKwargs
     from vllm.multimodal.inputs import PlaceholderRange
-    from vllm.v1.core.kv_cache_utils import BlockHashType


 class Request:
@@ -63,11 +62,6 @@ def __init__(
         if self.mm_hashes:
             assert len(self.mm_inputs) == len(self.mm_hashes)

-        # Cache the computed kv block hashes of the request to avoid
-        # recomputing.
-        self._kv_block_hashes: List[BlockHashType] = []
-        self.kv_block_hashes = ConstantList(self._kv_block_hashes)
-
         # Read-only views
         # Prevent directly appending to the these lists since
         # they should also be updated simultaneously.
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
         num_tokens = self.mm_positions[input_id]["length"]
         return num_tokens

-    def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
-        self._kv_block_hashes = value
-        self.kv_block_hashes = ConstantList(self._kv_block_hashes)
-
-    def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
-        self._kv_block_hashes.append(block_hash)
-

 class RequestStatus(enum.IntEnum):
     """Status of a request."""