
Commit 5e2d3bd

WoosukKwon authored and heheda12345 committed
cherry-pick: [V1] Move KV block hashes from Request to KVCacheManager (vllm-project#12922)
Signed-off-by: Chen Zhang <[email protected]>
1 parent a7173a2 commit 5e2d3bd

File tree

4 files changed (+38 -40 lines)
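
In short: the per-request cache of KV block hashes moves out of Request and into KVCacheManager, keyed by request ID, so the manager fully owns the hash lifecycle. A minimal sketch of the new ownership model, using simplified stand-in types rather than the real vllm classes (BlockHash below is a hypothetical alias for vllm's BlockHashType):

from collections import defaultdict
from typing import DefaultDict, List

BlockHash = int  # hypothetical stand-in for vllm's BlockHashType

class KVCacheManagerSketch:
    def __init__(self, num_kv_cache_groups: int) -> None:
        self.num_kv_cache_groups = num_kv_cache_groups
        # One hash list per KV cache group, keyed by request ID; the
        # manager, not the Request, owns the cached hashes.
        self.req_to_block_hashes: DefaultDict[
            str, List[List[BlockHash]]] = defaultdict(
                lambda: [[] for _ in range(self.num_kv_cache_groups)])

    def free_block_hashes(self, request_id: str) -> None:
        # Drop the entry only when the request finishes; preempted
        # requests keep their hashes and skip rehashing on resume.
        self.req_to_block_hashes.pop(request_id, None)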

tests/v1/core/test_prefix_caching.py

Lines changed: 11 additions & 10 deletions
@@ -63,7 +63,7 @@ def test_prefill():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert len(req0.kv_block_hashes[0]) == 3
+    assert len(manager.req_to_block_hashes[req0.request_id][0]) == 3
     assert not computed_blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks,
@@ -89,7 +89,7 @@ def test_prefill():
     unique_token_ids = [3] * 5
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes[0]) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id][0]) == 3
     assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -121,7 +121,7 @@ def test_prefill():
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(req2.kv_block_hashes[0]) == 3
+    assert len(manager.req_to_block_hashes[req2.request_id][0]) == 3
     assert [b.block_id for b in computed_blocks[0]] == [0, 1, 2]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
@@ -509,10 +509,11 @@ def test_mm_prefix_caching():
     # Completed block should have hashes with extra keys.
     assert not computed_blocks[0]
     assert num_computed_tokens == 0
-    assert len(req0.kv_block_hashes[0]) == 3
-    assert req0.kv_block_hashes[0][0].extra_keys == ("aaa", )
-    assert req0.kv_block_hashes[0][1].extra_keys == ("aaa", "bbb")
-    assert req0.kv_block_hashes[0][2].extra_keys == ("bbb", )
+    block_hashes = manager.req_to_block_hashes[req0.request_id]
+    assert len(block_hashes[0]) == 3
+    assert block_hashes[0][0].extra_keys == ("aaa", )
+    assert block_hashes[0][1].extra_keys == ("aaa", "bbb")
+    assert block_hashes[0][2].extra_keys == ("bbb", )

     blocks = manager.allocate_slots(req0, 59, computed_blocks,
                                     num_computed_tokens)
@@ -526,8 +527,8 @@ def test_mm_prefix_caching():
     assert new_blocks is not None and len(new_blocks[0]) == 0

     # The just completed block should have hashes with extra keys.
-    assert len(req0.kv_block_hashes[0]) == 4
-    assert req0.kv_block_hashes[0][3].extra_keys == ("ccc", )
+    assert len(block_hashes[0]) == 4
+    assert block_hashes[0][3].extra_keys == ("ccc", )

     # Cache hit.
     unique_token_ids = [-1] * 7 + [200] * 5
@@ -632,7 +633,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
-    assert len(req1.kv_block_hashes[0]) == 3
+    assert len(manager.req_to_block_hashes[req1.request_id][0]) == 3
     assert len(computed_blocks[0]) == 3
     blocks = manager.allocate_slots(req1, 7, computed_blocks)
     assert [b.block_id for b in blocks[0]] == [4]
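
All of the assertions above change the same way: block hashes are now looked up on the manager by request ID, then indexed by KV cache group, i.e. req0.kv_block_hashes[0] becomes manager.req_to_block_hashes[req0.request_id][0].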

vllm/v1/core/kv_cache_manager.py

Lines changed: 26 additions & 11 deletions
@@ -98,6 +98,13 @@ def __init__(
         self.req_to_blocks: DefaultDict[str, ReqKVCacheBlocks] = defaultdict(
             lambda: [[] for _ in range(self.num_kv_cache_groups)])

+        # Mapping from request ID to kv block hashes.
+        # This is to avoid recomputing the block hashes for each call of
+        # `get_computed_blocks` or `allocate_slots`.
+        self.req_to_block_hashes: DefaultDict[
+            str, List[List[BlockHashType]]] = defaultdict(
+                lambda: [[] for _ in range(self.num_kv_cache_groups)])
+
     @property
     def usage(self) -> float:
         return 1.0 - (self.free_block_queue.num_free_blocks /
@@ -121,17 +128,19 @@ def get_computed_blocks(self,
             return [[] for _ in self.managers], 0

         # The block hashes for the request may already be computed
-        # if the request was preempted and resumed.
-        if not request.kv_block_hashes:
-            request.set_kv_block_hashes([
+        # if the scheduler has tried to schedule the request before.
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        if not block_hashes:
+            block_hashes = [
                 hash_request_tokens(manager.block_size, request, i)
                 for i, manager in enumerate(self.managers)
-            ])
+            ]
+            self.req_to_block_hashes[request.request_id] = block_hashes

         computed_blocks: ReqKVCacheBlocks = []  # computed blocks of each group
         prefix_length: List[PrefixLength] = [
         ]  # possible cached prefix length of each group
-        block_hashes = request.kv_block_hashes
+
         for i, manager in enumerate(self.managers):
             prefix_length_i, computed_blocks_i = (
                 manager.get_possible_cached_prefix(block_hashes[i]))
@@ -154,7 +163,6 @@ def get_computed_blocks(self,
         for i, manager in enumerate(self.managers):
             computed_blocks[i] = computed_blocks[i][:num_computed_tokens //
                                                     manager.block_size]
-
         return computed_blocks, num_computed_tokens

     def allocate_slots(
@@ -560,8 +568,8 @@ def _cache_full_blocks(
             prev_block: The previous block in the chain.
             kv_cache_group_id: The KV cache group that the blocks belong to
         """
-        num_cached_block_hashes = len(
-            request.kv_block_hashes[kv_cache_group_id])
+        block_hashes = self.req_to_block_hashes[request.request_id]
+        num_cached_block_hashes = len(block_hashes[kv_cache_group_id])

         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
@@ -596,8 +604,7 @@ def _cache_full_blocks(
                 # this request (either the prompt tokens or the previously
                 # generated tokens with preemption). In this case we simply
                 # reuse the block hash.
-                block_hash = request.kv_block_hashes[kv_cache_group_id][
-                    blk_idx]
+                block_hash = block_hashes[kv_cache_group_id][blk_idx]
             else:
                 # Otherwise compute the block hash and cache it in the request
                 # in case it will be preempted in the future.
@@ -620,13 +627,21 @@ def _cache_full_blocks(
                 block_hash = hash_block_tokens(prev_block_hash_value,
                                                block_tokens, kv_cache_group_id,
                                                extra_keys)
-                request.append_kv_block_hashes(kv_cache_group_id, block_hash)
+                block_hashes[kv_cache_group_id].append(block_hash)

             # Update and add the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
             prev_block_hash_value = block_hash.hash_value

+    def free_block_hashes(self, request: Request) -> None:
+        """Discard the block hashes for the request.
+
+        NOTE: Unlike `free`, this method should be called only when the request
+        is finished, not when it is preempted.
+        """
+        self.req_to_block_hashes.pop(request.request_id, None)
+
     def get_null_block(self) -> KVCacheBlock:
         return self._null_block

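Taken together, the hunks above define the hash lifecycle: get_computed_blocks populates the per-request entry on first schedule, _cache_full_blocks appends as blocks fill up, and free_block_hashes discards the entry at completion. A usage sketch against the simplified manager from the top of this page (hypothetical request IDs and hash values):

manager = KVCacheManagerSketch(num_kv_cache_groups=1)

# First schedule: hashes are computed once and cached under the request ID.
hashes = manager.req_to_block_hashes["req-0"]
if not hashes[0]:
    hashes[0].extend([101, 102, 103])  # stand-in for hash_request_tokens()

# Preemption and resume: the entry survives, so nothing is rehashed.
assert manager.req_to_block_hashes["req-0"][0] == [101, 102, 103]

# Request finished: the entry is dropped.
manager.free_block_hashes("req-0")
assert "req-0" not in manager.req_to_block_hashes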

vllm/v1/core/scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -551,6 +551,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.kv_cache_manager.free_block_hashes(request)
         self.encoder_cache_manager.free(request)
         self._cached_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
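
Note the placement: free_block_hashes is called only from _free_request, which asserts request.is_finished(), while the preemption path calls kv_cache_manager.free alone. That is what lets a resumed request reuse its cached hashes instead of recomputing them.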

vllm/v1/request.py

Lines changed: 0 additions & 19 deletions
@@ -12,7 +12,6 @@
 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalKwargs
     from vllm.multimodal.inputs import PlaceholderRange
-    from vllm.v1.core.kv_cache_utils import BlockHashType


 class Request:
@@ -63,12 +62,6 @@ def __init__(
         if self.mm_hashes:
             assert len(self.mm_inputs) == len(self.mm_hashes)

-        # Cache the computed kv block hashes of the request to avoid
-        # recomputing.
-        self._kv_block_hashes: List[List[BlockHashType]] = []
-        self.kv_block_hashes = ConstantList(
-            [ConstantList(x) for x in self._kv_block_hashes])
-
         # Read-only views
         # Prevent directly appending to these lists since
         # they should also be updated simultaneously.
@@ -125,18 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
         num_tokens = self.mm_positions[input_id]["length"]
         return num_tokens

-    def set_kv_block_hashes(self, value: List[List["BlockHashType"]]) -> None:
-        self._kv_block_hashes = value
-        # NOTE: self.kv_block_hashes._x is not self._kv_block_hashes, but
-        # self.kv_block_hashes[0]._x is self._kv_block_hashes[0]. This is
-        # correct because we never need to update the outer list.
-        self.kv_block_hashes = ConstantList(
-            [ConstantList(x) for x in self._kv_block_hashes])
-
-    def append_kv_block_hashes(self, group_id: int,
-                               block_hash: "BlockHashType") -> None:
-        self._kv_block_hashes[group_id].append(block_hash)
-

 class RequestStatus(enum.IntEnum):
     """Status of a request."""
