Skip to content

Commit 05c73a8

Browse files
heheda12345GWS0428
authored andcommitted
[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (vllm-project#12003)
1 parent c8f09cf commit 05c73a8

File tree

3 files changed

+61
-35
lines changed

3 files changed

+61
-35
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ def test_prefill():
4949
unique_token_ids = [3] * 7
5050
all_token_ids = common_token_ids + unique_token_ids
5151
req0 = make_request("0", all_token_ids)
52-
computed_blocks = manager.get_computed_blocks(req0)
52+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
5353
assert len(req0.kv_block_hashes) == 3
5454
assert not computed_blocks
55+
assert num_computed_tokens == 0
5556
blocks = manager.allocate_slots(req0, 55, computed_blocks)
5657
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
5758

@@ -73,9 +74,10 @@ def test_prefill():
7374
# Incomplete 1 block (5 tokens)
7475
unique_token_ids = [3] * 5
7576
req1 = make_request("1", common_token_ids + unique_token_ids)
76-
computed_blocks = manager.get_computed_blocks(req1)
77+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
7778
assert len(req1.kv_block_hashes) == 3
7879
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
80+
assert num_computed_tokens == 3 * 16
7981
num_new_tokens = 53 - 3 * 16
8082
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
8183
assert [b.block_id for b in blocks] == [5, 6]
@@ -91,7 +93,7 @@ def test_prefill():
9193
# All blocks should be available.
9294
assert manager.free_block_queue.num_free_blocks == 10
9395
# The order should be
94-
# [unallocated (7, 8)]
96+
# [unallocated (7, 8, 9)]
9597
# [unique_req0 (4, 3)]
9698
# [unique_req1 (6, 5)]
9799
# [common (2, 1, 0)]
@@ -103,9 +105,10 @@ def test_prefill():
103105
# Incomplete 1 block (6 tokens)
104106
unique_token_ids = [3] * 6
105107
req2 = make_request("2", common_token_ids + unique_token_ids)
106-
computed_blocks = manager.get_computed_blocks(req2)
108+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
107109
assert len(req2.kv_block_hashes) == 3
108110
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
111+
assert num_computed_tokens == 3 * 16
109112
num_new_tokens = 53 - 3 * 16
110113
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
111114
assert [b.block_id for b in blocks] == [7, 8]
@@ -123,8 +126,9 @@ def test_prefill():
123126

124127
# Cache miss and eviction.
125128
req3 = make_request("3", [99] * (16 * 9))
126-
computed_blocks = manager.get_computed_blocks(req3)
129+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
127130
assert not computed_blocks
131+
assert num_computed_tokens == 0
128132
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
129133
# This block ID order also checks the eviction order.
130134
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
@@ -150,8 +154,9 @@ def test_decode():
150154
# Incomplete 1 block (7 tokens)
151155
unique_token_ids = [3] * 7
152156
req0 = make_request("0", common_token_ids + unique_token_ids)
153-
computed_blocks = manager.get_computed_blocks(req0)
157+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
154158
assert not computed_blocks
159+
assert num_computed_tokens == 0
155160
blocks = manager.allocate_slots(req0, 55, computed_blocks)
156161
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
157162

@@ -197,16 +202,18 @@ def test_evict():
197202

198203
last_token_id = 5 * 16 + 7
199204
req0 = make_request("0", list(range(last_token_id)))
200-
computed_blocks = manager.get_computed_blocks(req0)
205+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
201206
assert not computed_blocks
207+
assert num_computed_tokens == 0
202208
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
203209
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
204210

205211
# 3 blocks.
206212
req1 = make_request("1", list(range(last_token_id,
207213
last_token_id + 3 * 16)))
208-
computed_blocks = manager.get_computed_blocks(req1)
214+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
209215
assert not computed_blocks
216+
assert num_computed_tokens == 0
210217
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
211218
assert len(blocks) == 3 # 3 full blocks
212219
last_token_id += 3 * 16
@@ -222,8 +229,9 @@ def test_evict():
222229

223230
# Touch the first 2 blocks.
224231
req2 = make_request("2", list(range(2 * 16 + 3)))
225-
computed_blocks = manager.get_computed_blocks(req2)
232+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
226233
assert [b.block_id for b in computed_blocks] == [0, 1]
234+
assert num_computed_tokens == 2 * 16
227235
blocks = manager.allocate_slots(req2, 3, computed_blocks)
228236
assert [b.block_id for b in blocks] == [6, 5]
229237
assert manager.free_block_queue.num_free_blocks == 6
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
247255
# Allocate 1 block and cache it.
248256
num_tokens = block_size * 1
249257
req = make_request("0", list(range(num_tokens)))
250-
computed_blocks = manager.get_computed_blocks(req)
258+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
251259
assert not computed_blocks
260+
assert num_computed_tokens == 0
252261
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
253262
assert len(blocks) == 1
254263

@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
258267
# Allocate a new block that's not full, make sure hash info on the
259268
# block is cleared.
260269
req = make_request("1", list(range(num_tokens - 1)))
261-
computed_blocks = manager.get_computed_blocks(req)
270+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
262271
assert not computed_blocks
272+
assert num_computed_tokens == 0
263273
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
264274
assert len(blocks) == 1
265275

@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
284294
# Allocate a block and cache it.
285295
num_tokens = block_size * 1
286296
req0 = make_request("0", list(range(num_tokens)))
287-
computed_blocks = manager.get_computed_blocks(req0)
297+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
288298
assert not computed_blocks
299+
assert num_computed_tokens == 0
289300
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
290301
assert len(blocks) == 1
291302
assert blocks[0].block_id == 0
292303

293304
# Allocate another block.
294305
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
295-
computed_blocks = manager.get_computed_blocks(req1)
306+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
296307
assert not computed_blocks
308+
assert num_computed_tokens == 0
297309
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
298310
assert len(blocks) == 1
299311
assert blocks[0].block_id == 1
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
305317
# Now if we have a cache hit on the first block, we should evict the second
306318
# cached block rather than the first one.
307319
req2 = make_request("2", list(range(num_tokens * 2)))
308-
computed_blocks = manager.get_computed_blocks(req2)
320+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
309321
assert len(computed_blocks) == 1
310322
assert computed_blocks[0].block_id == 0
323+
assert num_computed_tokens == block_size
311324

312325
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
313326
computed_blocks)
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
331344

332345
req1 = make_request("1", list(range(10))) # 2 blocks and some more
333346

334-
computed_blocks = manager.get_computed_blocks(req1)
347+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
335348
assert not computed_blocks
349+
assert num_computed_tokens == 0
336350
blocks = manager.allocate_slots(req1, 10, computed_blocks)
337351
assert len(blocks) == 3
338352

@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
341355

342356
# No caching.
343357
req2 = make_request("2", list(range(16))) # shared prefix
344-
computed_blocks = manager.get_computed_blocks(req2)
358+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
345359
assert not computed_blocks
360+
assert num_computed_tokens == 0
346361
blocks = manager.allocate_slots(req2, 16, computed_blocks)
347362
assert len(blocks) == 4
348363

349364
# New requests should not have any blocks.
350365
req3 = make_request("3", list(range(4)))
351-
computed_blocks = manager.get_computed_blocks(req3)
366+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
352367
assert not computed_blocks
368+
assert num_computed_tokens == 0
353369
blocks = manager.allocate_slots(req3, 4, computed_blocks)
354370
assert not blocks
355371

@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
371387
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
372388

373389
req = make_request("0", list(range(block_size * 30)))
374-
computed_blocks = manager.get_computed_blocks(req)
390+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
375391
assert not computed_blocks
392+
assert num_computed_tokens == 0
376393
# Just ask for 1 block.
377394
blocks = manager.allocate_slots(req, block_size, computed_blocks)
378395
req.num_computed_tokens = block_size
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
469486
all_token_ids,
470487
mm_positions=mm_positions,
471488
mm_hashes=mm_hashes)
472-
computed_blocks = manager.get_computed_blocks(req0)
489+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
473490

474491
# Completed block should have hashes with extra keys.
475492
assert not computed_blocks
493+
assert num_computed_tokens == 0
476494
assert len(req0.kv_block_hashes) == 3
477495
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
478496
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
503521
all_token_ids,
504522
mm_positions=mm_positions,
505523
mm_hashes=mm_hashes)
506-
computed_blocks = manager.get_computed_blocks(req1)
524+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
507525
assert len(computed_blocks) == 3
526+
assert num_computed_tokens == 3 * 16
508527

509528

510529
def test_prefill_not_enough_free_blocks_with_computed_blocks():
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
527546
# | Common-0 | Common-1 | Common-2 | ... |
528547
common_token_ids = [i for i in range(3) for _ in range(16)]
529548
req0 = make_request("0", common_token_ids)
530-
computed_blocks = manager.get_computed_blocks(req0)
549+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
531550
assert not computed_blocks
551+
assert num_computed_tokens == 0
532552
manager.allocate_slots(req0, 48, computed_blocks)
533553
block_part0 = manager.req_to_blocks[req0.request_id]
534554

535555
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
536556
req1 = make_request("1", common_token_ids * 2)
537-
computed_blocks = manager.get_computed_blocks(req1)
557+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
538558
assert computed_blocks == block_part0
559+
assert num_computed_tokens == 3 * 16
539560
manager.allocate_slots(req1, 48, computed_blocks)
540561
block_part1 = manager.req_to_blocks[req1.request_id]
541562
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
@@ -547,17 +568,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
547568
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
548569
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
549570
req2 = make_request("2", [7] * block_size * 2)
550-
computed_blocks = manager.get_computed_blocks(req2)
571+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
551572
assert not computed_blocks
573+
assert num_computed_tokens == 0
552574
manager.allocate_slots(req2, block_size * 2, computed_blocks)
553575

554576
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
555577
# but it cannot be allocated due to insufficient free blocks (2).
556578
# In this case, the ref_cnt of the computed blocks should not be changed.
557579
assert manager.free_block_queue.num_free_blocks == 5
558580
req3 = make_request("3", common_token_ids * 3)
559-
computed_blocks = manager.get_computed_blocks(req3)
581+
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
560582
assert computed_blocks == block_part1
583+
assert num_computed_tokens == 6 * 16
561584
# Req3 cannot be allocated.
562585
assert manager.allocate_slots(req3, 48, computed_blocks) is None
563586
# Block 0-2 are used by Req 1.

vllm/v1/core/kv_cache_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import Dict, Iterable, List, Optional
2+
from typing import Dict, Iterable, List, Optional, Tuple
33

44
from vllm.logger import init_logger
55
from vllm.utils import cdiv
@@ -69,19 +69,22 @@ def __init__(
6969
# is finished.
7070
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
7171

72-
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
72+
def get_computed_blocks(
73+
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
7374
"""Get the computed (cached) blocks for the request.
7475
Note that the computed blocks must be full.
7576
7677
Args:
7778
request: The request to get the computed blocks.
7879
7980
Returns:
80-
A list of blocks that are computed for the request.
81+
A tuple containing:
82+
- A list of blocks that are computed for the request.
83+
- The number of computed tokens.
8184
"""
8285
if not self.enable_caching:
8386
# Prefix caching is disabled.
84-
return []
87+
return [], 0
8588

8689
computed_blocks = []
8790

@@ -101,7 +104,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
101104
else:
102105
break
103106

104-
return computed_blocks
107+
# NOTE(woosuk): Since incomplete blocks are not eligible for
108+
# sharing, `num_computed_tokens` is always a multiple of
109+
# `block_size`.
110+
num_computed_tokens = len(computed_blocks) * self.block_size
111+
return computed_blocks, num_computed_tokens
105112

106113
def append_slots(
107114
self,

vllm/v1/core/scheduler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,8 @@ def schedule(self) -> "SchedulerOutput":
184184

185185
request = self.waiting[0]
186186
# Get already-cached tokens.
187-
computed_blocks = self.kv_cache_manager.get_computed_blocks(
188-
request)
189-
# NOTE(woosuk): Since incomplete blocks are not eligible for
190-
# sharing, `num_computed_tokens` is always a multiple of
191-
# `block_size`.
192-
num_computed_tokens = len(computed_blocks) * self.block_size
187+
computed_blocks, num_computed_tokens = \
188+
self.kv_cache_manager.get_computed_blocks(request)
193189
# Number of tokens to be scheduled.
194190
# We use `request.num_tokens` instead of
195191
# `request.num_prompt_tokens` to consider the resumed requests,

0 commit comments

Comments
 (0)