Skip to content

Commit e8d0bbc

Browse files
WoosukKwon and SzymonOzog
authored and committed
[V1] Move KV block hashes from Request to KVCacheManager (vllm-project#12922)
Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: SzymonOzog <[email protected]>
1 parent 288fbcb commit e8d0bbc

File tree

4 files changed

+35
-31
lines changed

4 files changed

+35
-31
lines changed

tests/v1/core/test_prefix_caching.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_prefill():
5151
all_token_ids = common_token_ids + unique_token_ids
5252
req0 = make_request("0", all_token_ids)
5353
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
54-
assert len(req0.kv_block_hashes) == 3
54+
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
5555
assert not computed_blocks
5656
assert num_computed_tokens == 0
5757
blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -76,7 +76,7 @@ def test_prefill():
7676
unique_token_ids = [3] * 5
7777
req1 = make_request("1", common_token_ids + unique_token_ids)
7878
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
79-
assert len(req1.kv_block_hashes) == 3
79+
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
8080
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
8181
assert num_computed_tokens == 3 * 16
8282
num_new_tokens = 53 - 3 * 16
@@ -107,7 +107,7 @@ def test_prefill():
107107
unique_token_ids = [3] * 6
108108
req2 = make_request("2", common_token_ids + unique_token_ids)
109109
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
110-
assert len(req2.kv_block_hashes) == 3
110+
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
111111
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
112112
assert num_computed_tokens == 3 * 16
113113
num_new_tokens = 53 - 3 * 16
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
494494
# Completed block should have hashes with extra keys.
495495
assert not computed_blocks
496496
assert num_computed_tokens == 0
497-
assert len(req0.kv_block_hashes) == 3
498-
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
499-
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
500-
assert req0.kv_block_hashes[2].extra_keys == ("bbb", )
497+
block_hashes = manager.req_to_block_hashes[req0.request_id]
498+
assert len(block_hashes) == 3
499+
assert block_hashes[0].extra_keys == ("aaa", )
500+
assert block_hashes[1].extra_keys == ("aaa", "bbb")
501+
assert block_hashes[2].extra_keys == ("bbb", )
501502

502503
blocks = manager.allocate_slots(req0, 59, computed_blocks)
503504
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
510511
assert new_blocks is not None and len(new_blocks) == 0
511512

512513
# The just completed block should have hashes with extra keys.
513-
assert len(req0.kv_block_hashes) == 4
514-
assert req0.kv_block_hashes[3].extra_keys == ("ccc", )
514+
assert len(block_hashes) == 4
515+
assert block_hashes[3].extra_keys == ("ccc", )
515516

516517
# Cache hit.
517518
unique_token_ids = [-1] * 7 + [200] * 5
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
613614
all_token_ids = full_block_token_ids + unique_token_ids
614615
req1 = make_request("1", all_token_ids)
615616
computed_blocks, _ = manager.get_computed_blocks(req1)
616-
assert len(req1.kv_block_hashes) == 3
617+
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
617618
assert len(computed_blocks) == 3
618619
blocks = manager.allocate_slots(req1, 7, computed_blocks)
619620
assert [b.block_id for b in blocks] == [4]

vllm/v1/core/kv_cache_manager.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ def __init__(
7272
self.req_to_blocks: DefaultDict[str,
7373
List[KVCacheBlock]] = defaultdict(list)
7474

75+
# Mapping from request ID to kv block hashes.
76+
# This is to avoid recomputing the block hashes for each call of
77+
# `get_computed_blocks` or `allocate_slots`.
78+
self.req_to_block_hashes: DefaultDict[
79+
str, List[BlockHashType]] = defaultdict(list)
80+
7581
@property
7682
def usage(self) -> float:
7783
return 1.0 - (self.free_block_queue.num_free_blocks /
@@ -97,11 +103,11 @@ def get_computed_blocks(
97103
computed_blocks = []
98104

99105
# The block hashes for the request may already be computed
100-
# if the request was preempted and resumed.
101-
if not request.kv_block_hashes:
102-
request.set_kv_block_hashes(
103-
hash_request_tokens(self.block_size, request))
104-
block_hashes = request.kv_block_hashes
106+
# if the scheduler has tried to schedule the request before.
107+
block_hashes = self.req_to_block_hashes[request.request_id]
108+
if not block_hashes:
109+
block_hashes = hash_request_tokens(self.block_size, request)
110+
self.req_to_block_hashes[request.request_id] = block_hashes
105111

106112
for block_hash in block_hashes:
107113
# block_hashes is a chain of block hashes. If a block hash is not
@@ -435,7 +441,8 @@ def _cache_full_blocks(
435441
full_blocks: The list of blocks to update hash metadata.
436442
prev_block: The previous block in the chain.
437443
"""
438-
num_cached_block_hashes = len(request.kv_block_hashes)
444+
block_hashes = self.req_to_block_hashes[request.request_id]
445+
num_cached_block_hashes = len(block_hashes)
439446

440447
# Update the new blocks with the block hashes through the chain.
441448
prev_block_hash_value = None
@@ -468,7 +475,7 @@ def _cache_full_blocks(
468475
# this request (either the prompt tokens or the previously
469476
# generated tokens with preemption). In this case we simply
470477
# reuse the block hash.
471-
block_hash = request.kv_block_hashes[blk_idx]
478+
block_hash = block_hashes[blk_idx]
472479
else:
473480
# Otherwise compute the block hash and cache it in the request
474481
# in case it will be preempted in the future.
@@ -490,9 +497,17 @@ def _cache_full_blocks(
490497
# Compute the hash of the current block.
491498
block_hash = hash_block_tokens(prev_block_hash_value,
492499
block_tokens, extra_keys)
493-
request.append_kv_block_hashes(block_hash)
500+
block_hashes.append(block_hash)
494501

495502
# Update and added the full block to the cache.
496503
blk.block_hash = block_hash
497504
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
498505
prev_block_hash_value = block_hash.hash_value
506+
507+
def free_block_hashes(self, request: Request) -> None:
508+
"""Discard the block hashes for the request.
509+
510+
NOTE: Unlike `free`, this method should be called only when the request
511+
is finished, not when it is preempted.
512+
"""
513+
self.req_to_block_hashes.pop(request.request_id, None)

vllm/v1/core/scheduler.py

+1
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def finish_requests(
579579
def _free_request(self, request: Request) -> None:
580580
assert request.is_finished()
581581
self.kv_cache_manager.free(request)
582+
self.kv_cache_manager.free_block_hashes(request)
582583
self.encoder_cache_manager.free(request)
583584
self._cached_reqs_data.pop(request.request_id, None)
584585
del self.requests[request.request_id]

vllm/v1/request.py

-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
if TYPE_CHECKING:
1313
from vllm.multimodal import MultiModalKwargs
1414
from vllm.multimodal.inputs import PlaceholderRange
15-
from vllm.v1.core.kv_cache_utils import BlockHashType
1615

1716

1817
class Request:
@@ -63,11 +62,6 @@ def __init__(
6362
if self.mm_hashes:
6463
assert len(self.mm_inputs) == len(self.mm_hashes)
6564

66-
# Cache the computed kv block hashes of the request to avoid
67-
# recomputing.
68-
self._kv_block_hashes: List[BlockHashType] = []
69-
self.kv_block_hashes = ConstantList(self._kv_block_hashes)
70-
7165
# Read-only views
7266
# Prevent directly appending to the these lists since
7367
# they should also be updated simultaneously.
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
124118
num_tokens = self.mm_positions[input_id]["length"]
125119
return num_tokens
126120

127-
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
128-
self._kv_block_hashes = value
129-
self.kv_block_hashes = ConstantList(self._kv_block_hashes)
130-
131-
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
132-
self._kv_block_hashes.append(block_hash)
133-
134121

135122
class RequestStatus(enum.IntEnum):
136123
"""Status of a request."""

0 commit comments

Comments (0)